How to use Model to get transaction embeddings? #16

Open
bjchaudhari29 opened this issue Jul 8, 2021 · 13 comments

Comments

@bjchaudhari29

Hi Team,
Thanks a lot for sharing the code.
I was able to train the BERT model on the card dataset, but I am having trouble loading the saved model to generate embeddings. Could you please let me know how to load the model weights and how to generate an embedding for a transaction?

After creating an instance of the class TabFormerBertLM, I try to load the weights with the following command:

tab_net.from_pretrained('/content/drive/MyDrive/TabFormer/checkpoint-500/pytorch_model.bin')

After running this, I get the following error:
AttributeError: 'TabFormerBertLM' object has no attribute 'from_pretrained'

It would be very helpful if you could guide me in solving this problem.

Thank you.
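
A minimal sketch of why this call fails, assuming TabFormerBertLM is a plain wrapper class that keeps the actual Hugging Face model in its `model` attribute (later comments in this thread call `tab_net.model.from_pretrained`, which suggests as much). `Wrapper` and `InnerModel` are hypothetical stand-ins, not the real TabFormer or transformers classes:

```python
# Hypothetical stand-ins: not the real TabFormer or transformers classes.
class InnerModel:
    """Plays the role of the Hugging Face model held by the wrapper."""

    @classmethod
    def from_pretrained(cls, path):
        # The real method would read weights from `path`;
        # this stand-in only records where it was asked to load from.
        model = cls()
        model.loaded_from = path
        return model


class Wrapper:
    """Plays the role of TabFormerBertLM: a plain class that stores a model."""

    def __init__(self):
        self.model = InnerModel()


tab_net = Wrapper()

# The wrapper itself does not define from_pretrained, so the lookup fails.
try:
    tab_net.from_pretrained("checkpoint-500")
    wrapper_error = None
except AttributeError as exc:
    wrapper_error = str(exc)

# The wrapped model does define it, so the call has to go through `.model`.
inferencer = tab_net.model.from_pretrained("checkpoint-500")

print(wrapper_error)
print(inferencer.loaded_from)
```

If that assumption holds, the method to call is the one on `tab_net.model`, not on `tab_net` itself.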

@sevstafiev

Hey, did you solve the problem?

@bjchaudhari29
Author

Not yet.

@ink-pad
Collaborator

ink-pad commented Jul 27, 2021

Hi @bjchaudhari29 / @sevstafiev :

Apologies for not getting back on this earlier.

I know why you are seeing this issue! I will try to get back to this later this week, when I get time, and share a code snippet showing how to load the model properly. However, if you can't wait until then, please look at the branch gpt_cc_user_eval, which does exactly what you are attempting, but on the GPT model.

@sevstafiev

It would be amazing if you could share a code snippet for the BERT model! Thanks.

@bjchaudhari29
Author

Hi @ink-pad,
Thank you for the reply; however, I am not able to access the gpt_cc_user_eval branch.

@ink-pad
Collaborator

ink-pad commented Jul 30, 2021

@bjchaudhari29 :

My bad - I have fixed the link now!

@bjchaudhari29
Author

Hi @ink-pad ,
Did you get time to work on loading the BERT model to get transaction-level embeddings? It would be great if you could share that code.

Thank you.

@bjchaudhari29
Author

Hi,
I am trying to load the BERT model the same way the GPT model is loaded in the given link (https://github.com/IBM/TabFormer/blob/gpt_cc_user_eval/gpt_eval.py), but it gives the following error.

inferencer = tab_net.model.from_pretrained('/content/drive/MyDrive/TabFormer/checkpoint-500', vocab=dataset.vocab).to(device)

File "/usr/local/lib/python3.7/dist-packages/transformers/modeling_utils.py", line 844, in from_pretrained
config, model_kwargs = cls.config_class.from_pretrained(
AttributeError: 'NoneType' object has no attribute 'from_pretrained'
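
The traceback points at `cls.config_class.from_pretrained`, so `config_class` on the model class appears to be `None`. A minimal sketch of that mechanism, using hypothetical stand-in classes (`StandInConfig`, `StandInModel`) rather than the real transformers code, and assuming the loader only consults `config_class` when no config object is passed in:

```python
# Hypothetical stand-ins: not the real transformers or TabFormer classes.
class StandInConfig:
    """Plays the role of a config class such as TabFormerBertConfig."""

    def __init__(self, hidden_size=768):
        self.hidden_size = hidden_size

    @classmethod
    def from_pretrained(cls, path):
        # The real method would parse config.json; this one returns defaults.
        return cls()


class StandInModel:
    # No config class registered, as the traceback implies.
    config_class = None

    def __init__(self, config):
        self.config = config

    @classmethod
    def from_pretrained(cls, path, config=None):
        if config is None:
            # Same shape as the failing line in the traceback.
            config = cls.config_class.from_pretrained(path)
        return cls(config)


# Without a config, the loader falls back to config_class, which is None.
try:
    StandInModel.from_pretrained("checkpoint-500")
    error_message = None
except AttributeError as exc:
    error_message = str(exc)

# With an explicit config object, config_class is never consulted.
config = StandInConfig.from_pretrained("checkpoint-500/config.json")
model = StandInModel.from_pretrained("checkpoint-500", config=config)

print(error_message)
print(model.config.hidden_size)
```

Under that assumption, loading the config separately and passing it as `config=...` sidesteps the `None` lookup.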

@kekayan

kekayan commented Dec 18, 2021

@bjchaudhari29
Try this to load the model:

config = TabFormerBertConfig.from_pretrained(
    "/output/checkpoint-2000/config.json"
)
model = tab_net.model.from_pretrained(
    "/output/checkpoint-2000/pytorch_model.bin",
    config=config,
    vocab=dataset.vocab,
).to(device)

@kekayan

kekayan commented Dec 20, 2021

Hi @ink-pad ,

Thanks for open-sourcing the code. I have a question: do we need to change line 118 in tabformer_bert.py
from outputs = (prediction_scores,) + outputs[2:]
to outputs = (prediction_scores,) + outputs
to get the transaction/row embeddings for each row?
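
To make the effect of that one-line change concrete, here is a small sketch using plain tuples of placeholder strings instead of tensors. It assumes the inner encoder returns a tuple whose first two items are the sequence output and the pooled output, as Hugging Face BERT models do when returning tuples; whether the sequence output is the row embedding the authors used is not confirmed anywhere in this thread. Both function names are hypothetical:

```python
# Placeholder strings stand in for tensors; both function names are hypothetical.
def outputs_with_slice(encoder_outputs, prediction_scores):
    # Mirrors: outputs = (prediction_scores,) + outputs[2:]
    return (prediction_scores,) + encoder_outputs[2:]


def outputs_without_slice(encoder_outputs, prediction_scores):
    # Mirrors: outputs = (prediction_scores,) + outputs
    return (prediction_scores,) + encoder_outputs


# Assumed encoder return value: (sequence_output, pooled_output).
encoder_outputs = ("sequence_output", "pooled_output")

sliced = outputs_with_slice(encoder_outputs, "prediction_scores")
unsliced = outputs_without_slice(encoder_outputs, "prediction_scores")

print(sliced)    # the slice drops the first two encoder items
print(unsliced)  # the encoder items follow prediction_scores
```

With the slice in place, the per-position hidden states never leave `forward`; without it, they would sit at index 1 of the returned tuple, shifted one place by `prediction_scores`.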

@monk1337

Any update on this issue?

@shaoyijia

Hi @ink-pad ,

Thanks for open-sourcing the code. I have a question: do we need to change line 118 in tabformer_bert.py from outputs = (prediction_scores,) + outputs[2:] to outputs = (prediction_scores,) + outputs to get the transaction/row embeddings for each row?

Hi @kekayan ,

Have you found the correct way to get the embeddings for each row? I want to reproduce the results on the fraud detection task and have the same problem.

Thanks for your help!

@shamgane

Hi, I have the same question. Any update on how to obtain the embeddings for each row? Some code-level guidance would help :)
