
Truncation of sequences that are beyond the model's maximum length #359

Open
MootezSaaD opened this issue Jan 14, 2024 · 2 comments
Labels
feat/tokenization Feature: Tokenization/piecer type/bug Type: Bug type/feature Type: Feature

Comments

@MootezSaaD

Hi,
First, I would like to thank you for this library :-) I'm really enjoying it.

I tried to tokenize a sequence of around 4K tokens and feed it to a RoBERTa-based model (CodeBERT). This led to the following error:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
[<ipython-input-12-a373b5333f39>](https://localhost:8080/#) in <cell line: 1>()
      2    ids = input_sentence.padded_tensor(padding_id=0, pad_left=True)
      3    mask = input.attention_mask(pad_left=True)
----> 4    model_output = encoder(piece_ids=ids, attention_mask=mask)

10 frames
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

[/usr/local/lib/python3.10/dist-packages/curated_transformers/models/transformer.py](https://localhost:8080/#) in forward(self, piece_ids, attention_mask, positions, type_ids)
    122         type_ids: Optional[Tensor] = None,
    123     ) -> ModelOutput:
--> 124         embeddings = self.embeddings(piece_ids, positions=positions, type_ids=type_ids)
    125         layer_output = embeddings
    126 

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

[/usr/local/lib/python3.10/dist-packages/curated_transformers/models/roberta/embeddings.py](https://localhost:8080/#) in forward(self, piece_ids, positions, type_ids)
     96         if positions is None:
     97             positions = self._get_positions(piece_ids)
---> 98         return super().forward(
     99             piece_ids,
    100             positions=positions,

[/usr/local/lib/python3.10/dist-packages/curated_transformers/layers/transformer.py](https://localhost:8080/#) in forward(self, piece_ids, positions, type_ids)
    180             if positions is None:
    181                 positions = self._get_positions(piece_ids)
--> 182             position_embeddings = self.position_embeddings(positions)
    183             embeddings += position_embeddings
    184 

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/sparse.py](https://localhost:8080/#) in forward(self, input)
    160 
    161     def forward(self, input: Tensor) -> Tensor:
--> 162         return F.embedding(
    163             input, self.weight, self.padding_idx, self.max_norm,
    164             self.norm_type, self.scale_grad_by_freq, self.sparse)

[/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py](https://localhost:8080/#) in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2231         # remove once script supports set_grad_enabled
   2232         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2233     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
   2234 
   2235 

IndexError: index out of range in self

For reference, here is the code I was using:

MODEL_TAG = "microsoft/codebert-base"
tokenizer = AutoTokenizer.from_hf_hub(name=MODEL_TAG, revision="main")
model = RoBERTaEncoder.from_hf_hub(
    name=MODEL_TAG,
    revision="main",
)
code = [
   'void avcodec_string(char *buf, int buf_size, AVCodecContext *enc, int encode)\n\n{\n\n    const char *codec_type;\n\n    const char *codec_name;\n\n    const char *profile = NULL;\n\n    const AVCodec *p;\n\n    int64_t bitrate;\n\n    int new_line = 0;\n\n    AVRational display_aspect_ratio;\n\n    const char *separator = enc->dump_separator ? (const char *)enc->dump_separator : ", ";\n\n\n\n    if (!buf || buf_size <= 0)\n\n        return;\n\n    codec_type = av_get_media_type_string(enc->codec_type);\n\n    codec_name = avcodec_get_name(enc->codec_id);\n\n    if (enc->profile != FF_PROFILE_UNKNOWN) {\n\n        if (enc->codec)\n\n            p = enc->codec;\n\n        else\n\n            p = encode ? avcodec_find_encoder(enc->codec_id) :\n\n                        avcodec_find_decoder(enc->codec_id);\n\n        if (p)\n\n            profile = av_get_profile_name(p, enc->profile);\n\n    }\n\n\n\n    snprintf(buf, buf_size, "%s: %s", codec_type ? codec_type : "unknown",\n\n             codec_name);\n\n    buf[0] ^= \'a\' ^ \'A\'; /* first letter in uppercase */\n\n\n\n    if (enc->codec && strcmp(enc->codec->name, codec_name))\n\n        snprintf(buf + strlen(buf), buf_size - strlen(buf), " (%s)", enc->codec->name);\n\n\n\n    if (profile)\n\n        snprintf(buf + strlen(buf), buf_size - strlen(buf), " (%s)", profile);\n\n    if (   enc->codec_type == AVMEDIA_TYPE_VIDEO\n\n        && av_log_get_level() >= AV_LOG_VERBOSE\n\n        && enc->refs)\n\n        snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                 ", %d reference frame%s",\n\n                 enc->refs, enc->refs > 1 ? 
"s" : "");\n\n\n\n    if (enc->codec_tag) {\n\n        char tag_buf[32];\n\n        av_get_codec_tag_string(tag_buf, sizeof(tag_buf), enc->codec_tag);\n\n        snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                 " (%s / 0x%04X)", tag_buf, enc->codec_tag);\n\n    }\n\n\n\n    switch (enc->codec_type) {\n\n    case AVMEDIA_TYPE_VIDEO:\n\n        {\n\n            char detail[256] = "(";\n\n\n\n            av_strlcat(buf, separator, buf_size);\n\n\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                 "%s", enc->pix_fmt == AV_PIX_FMT_NONE ? "none" :\n\n                     av_get_pix_fmt_name(enc->pix_fmt));\n\n            if (enc->bits_per_raw_sample && enc->pix_fmt != AV_PIX_FMT_NONE &&\n\n                enc->bits_per_raw_sample < av_pix_fmt_desc_get(enc->pix_fmt)->comp[0].depth)\n\n                av_strlcatf(detail, sizeof(detail), "%d bpc, ", enc->bits_per_raw_sample);\n\n            if (enc->color_range != AVCOL_RANGE_UNSPECIFIED)\n\n                av_strlcatf(detail, sizeof(detail), "%s, ",\n\n                            av_color_range_name(enc->color_range));\n\n\n\n            if (enc->colorspace != AVCOL_SPC_UNSPECIFIED ||\n\n                enc->color_primaries != AVCOL_PRI_UNSPECIFIED ||\n\n                enc->color_trc != AVCOL_TRC_UNSPECIFIED) {\n\n                if (enc->colorspace != (int)enc->color_primaries ||\n\n                    enc->colorspace != (int)enc->color_trc) {\n\n                    new_line = 1;\n\n                    av_strlcatf(detail, sizeof(detail), "%s/%s/%s, ",\n\n                                av_color_space_name(enc->colorspace),\n\n                                av_color_primaries_name(enc->color_primaries),\n\n                                av_color_transfer_name(enc->color_trc));\n\n                } else\n\n                    av_strlcatf(detail, sizeof(detail), "%s, ",\n\n                                av_get_colorspace_name(enc->colorspace));\n\n            
}\n\n\n\n            if (av_log_get_level() >= AV_LOG_DEBUG &&\n\n                enc->chroma_sample_location != AVCHROMA_LOC_UNSPECIFIED)\n\n                av_strlcatf(detail, sizeof(detail), "%s, ",\n\n                            av_chroma_location_name(enc->chroma_sample_location));\n\n\n\n            if (strlen(detail) > 1) {\n\n                detail[strlen(detail) - 2] = 0;\n\n                av_strlcatf(buf, buf_size, "%s)", detail);\n\n            }\n\n        }\n\n\n\n        if (enc->width) {\n\n            av_strlcat(buf, new_line ? separator : ", ", buf_size);\n\n\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     "%dx%d",\n\n                     enc->width, enc->height);\n\n\n\n            if (av_log_get_level() >= AV_LOG_VERBOSE &&\n\n                (enc->width != enc->coded_width ||\n\n                 enc->height != enc->coded_height))\n\n                snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                         " (%dx%d)", enc->coded_width, enc->coded_height);\n\n\n\n            if (enc->sample_aspect_ratio.num) {\n\n                av_reduce(&display_aspect_ratio.num, &display_aspect_ratio.den,\n\n                          enc->width * enc->sample_aspect_ratio.num,\n\n                          enc->height * enc->sample_aspect_ratio.den,\n\n                          1024 * 1024);\n\n                snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                         " [SAR %d:%d DAR %d:%d]",\n\n                         enc->sample_aspect_ratio.num, enc->sample_aspect_ratio.den,\n\n                         display_aspect_ratio.num, display_aspect_ratio.den);\n\n            }\n\n            if (av_log_get_level() >= AV_LOG_DEBUG) {\n\n                int g = av_gcd(enc->time_base.num, enc->time_base.den);\n\n                snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                         ", %d/%d",\n\n                         enc->time_base.num / g, 
enc->time_base.den / g);\n\n            }\n\n        }\n\n        if (encode) {\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     ", q=%d-%d", enc->qmin, enc->qmax);\n\n        } else {\n\n            if (enc->properties & FF_CODEC_PROPERTY_CLOSED_CAPTIONS)\n\n                snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                         ", Closed Captions");\n\n            if (enc->properties & FF_CODEC_PROPERTY_LOSSLESS)\n\n                snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                         ", lossless");\n\n        }\n\n        break;\n\n    case AVMEDIA_TYPE_AUDIO:\n\n        av_strlcat(buf, separator, buf_size);\n\n\n\n        if (enc->sample_rate) {\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     "%d Hz, ", enc->sample_rate);\n\n        }\n\n        av_get_channel_layout_string(buf + strlen(buf), buf_size - strlen(buf), enc->channels, enc->channel_layout);\n\n        if (enc->sample_fmt != AV_SAMPLE_FMT_NONE) {\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     ", %s", av_get_sample_fmt_name(enc->sample_fmt));\n\n        }\n\n        if (   enc->bits_per_raw_sample > 0\n\n            && enc->bits_per_raw_sample != av_get_bytes_per_sample(enc->sample_fmt) * 8)\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     " (%d bit)", enc->bits_per_raw_sample);\n\n        break;\n\n    case AVMEDIA_TYPE_DATA:\n\n        if (av_log_get_level() >= AV_LOG_DEBUG) {\n\n            int g = av_gcd(enc->time_base.num, enc->time_base.den);\n\n            if (g)\n\n                snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                         ", %d/%d",\n\n                         enc->time_base.num / g, enc->time_base.den / g);\n\n        }\n\n        break;\n\n    case AVMEDIA_TYPE_SUBTITLE:\n\n        if (enc->width)\n\n            snprintf(buf + 
strlen(buf), buf_size - strlen(buf),\n\n                     ", %dx%d", enc->width, enc->height);\n\n        break;\n\n    default:\n\n        return;\n\n    }\n\n    if (encode) {\n\n        if (enc->flags & AV_CODEC_FLAG_PASS1)\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     ", pass 1");\n\n        if (enc->flags & AV_CODEC_FLAG_PASS2)\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     ", pass 2");\n\n    }\n\n    bitrate = get_bit_rate(enc);\n\n    if (bitrate != 0) {\n\n        snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                 ", %"PRId64" kb/s", bitrate / 1000);\n\n    } else if (enc->rc_max_rate > 0) {\n\n        snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                 ", max. %"PRId64" kb/s", (int64_t)enc->rc_max_rate / 1000);\n\n    }\n\n}\n',
]
with torch.no_grad():
    input_sentence = tokenizer(code)
    ids = input_sentence.padded_tensor(padding_id=0, pad_left=False)
    mask = input_sentence.attention_mask(pad_left=False)
    model_output = model(piece_ids=ids, attention_mask=mask)

I went through the API docs and skimmed the source code, and it appears that truncation is not supported. Note that when I manually truncated the sequence, I was able to feed it to the RoBERTa encoder without error.
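For anyone hitting the same error before truncation lands in the library, clipping the piece ids yourself before building the padded tensor works as a stopgap. The sketch below operates on plain Python lists of ids and is independent of the library's API; `max_len` and `eos_id` are illustrative placeholders (RoBERTa-based models such as CodeBERT typically accept up to 512 pieces, and the actual end-of-sequence id should be taken from the tokenizer):

```python
def truncate_ids(ids, max_len=512, eos_id=2):
    """Clip a sequence of piece ids to max_len.

    If anything is cut off, overwrite the last kept piece with the
    end-of-sequence id so the truncated sequence stays well-formed.
    """
    if len(ids) <= max_len:
        return list(ids)
    clipped = list(ids[:max_len])
    clipped[-1] = eos_id
    return clipped
```

Applying this to each sequence returned by the tokenizer before calling `padded_tensor` keeps every position index within the range of the model's position embedding table, which is what the `IndexError` above is tripping over.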

@MootezSaaD MootezSaaD changed the title Truncation sequences that are beyond the model's maximum length Truncation of sequences that are beyond the model's maximum length Jan 14, 2024
@shadeMe shadeMe added feat/tokenization Feature: Tokenization/piecer type/bug Type: Bug type/feature Type: Feature labels Jan 14, 2024
@shadeMe
Collaborator

shadeMe commented Jan 14, 2024

Thanks for the report! As you surmised, we don't currently support the truncation of inputs, but that error message can definitely be improved. We'll look into it, but please feel free to contribute a PR if you'd like to sort it out yourself 😃

@danieldk
Collaborator

Just wanted to add that we do support longer sequences with Curated Transformers in spaCy. We should probably provide something similar in Curated Transformers that could be used as an extension.
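The usual way to support longer inputs without discarding tokens is to split the sequence into overlapping, fixed-size windows and encode each window separately. The sketch below is a generic illustration of that idea over raw piece ids, not the spaCy or Curated Transformers implementation; the window size and stride are illustrative:

```python
def windows(ids, max_len=512, stride=256):
    """Split a long sequence of piece ids into overlapping windows.

    Each window is at most max_len pieces; consecutive windows start
    stride pieces apart, so they overlap by max_len - stride pieces.
    """
    out = []
    start = 0
    while start < len(ids):
        out.append(ids[start:start + max_len])
        if start + max_len >= len(ids):
            break
        start += stride
    return out
```

Each window fits within the model's position embedding table; the overlap gives every token some bidirectional context, and the per-window outputs can then be stitched back together downstream.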
