Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

quantize_embeddings + KeyedJaggedTensor + VBE cannot work #1894

Open
yjjinjie opened this issue Apr 18, 2024 · 3 comments
Open

quantize_embeddings + KeyedJaggedTensor + VBE cannot work #1894

yjjinjie opened this issue Apr 18, 2024 · 3 comments

Comments

@yjjinjie
Copy link



import torch
from torchrec import KeyedJaggedTensor
from torchrec import EmbeddingBagConfig,EmbeddingConfig
from torchrec import EmbeddingBagCollection,EmbeddingCollection


# Variable-batch-size (VBE) KeyedJaggedTensor used as the repro input.
# Two things make this a VBE input: stride_per_key_per_rank gives a
# per-key batch size ([1] for 'user_id' and 'raw_4', [2] for the other
# keys), and inverse_indices maps each key's deduplicated rows back to
# the full batch of 2.  This VBE path is what the quantized EBC below
# fails on.
kt2 = KeyedJaggedTensor(
    keys=['user_id', 'item_id', 'id_3', 'id_4', 'id_5', 'raw_1', 'raw_4', 'combo_1', 'lookup_2', 'lookup_3', 'lookup_4', 'match_2', 'match_3', 'match_4', 'click_50_seq__item_id', 'click_50_seq__id_3', 'click_50_seq__raw_1'], 
    values=torch.tensor([573174,   5073,   3562,      3,     18,     13,     11,     49,     26,
             4,      2,      2,      4,      2,      4, 736847, 849333, 997432,
        640218,   9926,   9926,      0,      0,      0,      0,  59926,  59926,
             0,      0,      0,      0,   2835,    769,   1265,   8232,   6399,
           114,   7487,   2876,    953,   7840,   7538,   7998,   7852,   3528,
          1475,   7620,   6110,    572,    735,   4405,   5655,   6736,   2173,
          3421,   2311,   7122,   2159,   4535,   2162,   4657,   3151,   4522,
          1075,    306,   8968,   2056,   2256,   3919,   8624,   5372,   6018,
          3861,   4114,   3984,   2287,   1481,   4757,   1189,   2518,    913,
          9421,   3093,   5911,   9704,   8168,   9410,    728,   2451,    243,
          5187,   5836,   8830,   4894,    614,   7705,   9258,   3518,   4434,
             4,      2,      4,      2,      4,      2,      3,      2,      2,
             3,      3,      3,      4,      4,      3,      0,      4,      0,
             2,      2,      3,      4,      4,      0,      2,      2,      4,
             0,      3,      2,      2,      3,      0,      4,      0,      4,
             4,      4,      2,      2,      3,      4,      2,      4,      3,
             4,      2,      4,      2,      2,      2,      2,      0,      3,
             4,      4,      3,      2,      4,      4,      4,      4,      3,
             2,      3,      4,      2,      4,      0,      4,      4,      4,
             4,      0,      0,      2,      1,      1,      0,      3,      4,
             4,      2,      4,      1,      1,      4,      2,      2,      4,
             0,      4,      4,      4,      4,      4,      1,      4,      2,
             0,      0,      0,      2,      4,      4,      2,      4,      2,
             4,      4,      1,      1,      4,      1,      4,      4,      1,
             0,      4,      4,      4,      3,      0,      0,      2,      4,
             2,      2,      4,      4,      4,      2,      2,      4,      2,
             3]),
    # One length per (key, sample); samples-per-key varies (VBE).
    lengths=torch.tensor([ 1,  1,  1,  1,  0,  0,  1,  2,  2,  1,  1,  4,  2,  2,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1, 24, 44, 24, 44, 24, 44], dtype=torch.int64),
    # Per-key batch size on the single rank: non-uniform => VBE.
    stride_per_key_per_rank=[[1], [2], [2], [2], [2], [2], [1], [2], [2], [2], [2], [2], [2], [2], [2], [2], [2]],
    # (keys, index tensor) expanding each key's deduped rows back to batch 2.
    inverse_indices=(['user_id', 'item_id', 'id_3', 'id_4', 'id_5', 'raw_1', 'raw_4', 'combo_1', 'lookup_2', 'lookup_3', 
                      'lookup_4', 'match_2', 'match_3', 'match_4', 'click_50_seq__item_id', 'click_50_seq__id_3', 
                      'click_50_seq__raw_1'], 
                     torch.tensor([[0, 0], [0, 1],[0, 1], [0, 1], [0, 1], [0, 1],[0, 0], [0, 1], [0, 1], [0, 1],
                                   [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1]])
    )
)

# Embedding-table specs: (table name, num_embeddings, embedding_dim, features).
# All tables share the same weight-init / pruning / position flags below.
_TABLE_SPECS = [
    ("user_id_emb", 1000000, 16, ["user_id"]),
    ("item_id_emb", 10000, 16, ["item_id"]),
    ("id_3_emb", 5, 8, ["id_3"]),
    ("id_4_emb", 100, 16, ["id_4", "id_5"]),
    ("raw_1_emb", 5, 16, ["raw_1"]),
    ("raw_4_emb", 5, 16, ["raw_4"]),
    ("combo_1_emb", 1000000, 16, ["combo_1"]),
    ("lookup_2_emb", 10000, 8, ["lookup_2"]),
    ("lookup_3_emb", 1000, 8, ["lookup_3"]),
    ("lookup_4_emb", 5, 16, ["lookup_4"]),
    ("match_2_emb", 100000, 16, ["match_2"]),
    ("match_3_emb", 10000, 8, ["match_3"]),
    ("match_4_emb", 5, 16, ["match_4"]),
]

eb_configs2 = [
    EmbeddingBagConfig(
        num_embeddings=rows,
        embedding_dim=dim,
        name=table_name,
        feature_names=features,
        weight_init_max=None,
        weight_init_min=None,
        pruning_indices_remapping=None,
        need_pos=False,
    )
    for table_name, rows, dim, features in _TABLE_SPECS
]
ebc = EmbeddingBagCollection(eb_configs2)

# The fp32 EBC handles the VBE input without error.
print(ebc(kt2))
from torchrec.inference.modules import quantize_embeddings


import torch
import torch.nn as nn

class EmbeddingGroupImpl(nn.Module):
    """Thin nn.Module wrapper around an embedding collection.

    Exists only to reproduce the quantize_embeddings + VBE failure:
    ``quantize_embeddings(..., inplace=True)`` swaps ``self.ebc`` for its
    quantized counterpart inside this wrapper.
    """

    def __init__(self, ebc):
        super().__init__()
        # Registered as a submodule so quantize_embeddings can find and
        # replace it in place.
        self.ebc = ebc

    def forward(self, sparse_feature):
        # BUG FIX: the original called self.ebc(...) but discarded the
        # result, so this module always returned None and the final
        # print(quant_model(kt2)) could never show embeddings even when
        # the lookup succeeded.  Return the lookup result instead.
        return self.ebc(sparse_feature)

# Wrap the EBC and run the fp32 path once: this succeeds.
a=EmbeddingGroupImpl(ebc=ebc)
a.forward(kt2)

# Quantize in place (swaps the EBC for the quantized inference EBC),
# then re-run on the same VBE input -- this is what raises the
# bounds_check_indices error in the traceback below.
quant_model = quantize_embeddings(a, dtype=torch.qint8, inplace=True)
print(quant_model(kt2))

Error output:

Traceback (most recent call last):
  File "/larec/tzrec/tests/test_per2.py", line 89, in <module>
    print(quant_model(kt2))
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/larec/tzrec/tests/test_per2.py", line 83, in forward
    self.ebc(sparse_feature)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torchrec/quant/embedding_modules.py", line 487, in forward
    else emb_op.forward(
  File "/opt/conda/lib/python3.10/site-packages/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py", line 764, in forward
    torch.ops.fbgemm.bounds_check_indices(
  File "/opt/conda/lib/python3.10/site-packages/torch/_ops.py", line 758, in __call__
    return self._op(*args, **(kwargs or {}))
RuntimeError: offsets size 27 is not equal to B (1) * T (14) + 1
@yjjinjie
Copy link
Author

@henrylhtsang please see this problem

@PaulZhang12
Copy link
Contributor

I don't believe VBE + quantized EBC is yet supported. Quantized EBC uses a completely different FBGEMM TBE than the standard EBC for training

@yjjinjie
Copy link
Author

can you support VBE + quantized EBC for inference?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants