
How to obtain perplexity evaluation datasets? #11

Open
LGH1gh opened this issue Dec 14, 2022 · 1 comment

Comments

LGH1gh commented Dec 14, 2022

Dear Author,

Thanks for releasing RITA for protein generation!
However, I wonder how I can obtain the perplexity evaluation datasets used in your paper, and how to calculate perplexity.
I hope for your suggestions. Thanks in advance!

@Detopall

You can use the following code to calculate perplexity. I can't really help you with obtaining the perplexity evaluation datasets used in their paper.

import math
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model = AutoModelForCausalLM.from_pretrained("lightonai/RITA_s", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("lightonai/RITA_s")
model.eval()

# Sample two sequences from the model.
rita_gen = pipeline('text-generation', model=model, tokenizer=tokenizer)
sequences = rita_gen("MAB", max_length=200, do_sample=True, top_k=950, repetition_penalty=1.2,
                     num_return_sequences=2, eos_token_id=2)

def calculate_perplexity(sequence, model, tokenizer):
    input_ids = torch.tensor(tokenizer.encode(sequence)).unsqueeze(0)
    input_ids = input_ids.to(model.device)

    # Passing the input ids as labels makes the model compute the mean
    # cross-entropy loss over the sequence; perplexity is exp(loss).
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
    loss = outputs.loss

    return math.exp(loss)

for seq in sequences:
    print(f"seq: {seq['generated_text'].replace(' ', '')}")
    ppl = calculate_perplexity(seq['generated_text'], model, tokenizer)
    print(f"Perplexity: {ppl}\n")

With these results:

seq: MABVVGTALYPGSDRFDGEYEVDIVIDTDGARYVLPVINTITHVKQGTSTRHPLGKAGQARKYATMHTGNLVLHLFDKGHTGVSIHGTSIDERIFGADGRVIAEAQGSGDMRHYGISPNRVAVCVARPFGGEGFSVPLSIHALGNETGVQTTGSGDVSTTSAVEGPAQEQMGFLDHTLSYASSTILTYRTQVTTGLGGAR
Perplexity: 132566.77587907546

seq: MABPVVTREPGVYFLAPRVSKFYEIIPWWNEMYVIECSIVSAAAGAPAVTPIQIRAPDVDIMSQVTSTAGMTAFVKVKRSRVIKMYQRVEPVERLHALVGGASILLDASLPQAALVTIEGGDIFEVFHGTEGLLAIIDGAIQQGLFSYKM
Perplexity: 127686.55561821107

The lower the perplexity score, the better. The lower perplexity of the second sequence means the model assigns it higher likelihood, i.e., it looks more like the sequences the model was trained on, which suggests it is the better-quality sequence of the two.
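If you want a single perplexity number over a whole evaluation set (rather than per sequence), the standard approach is to weight each sequence's mean loss by its token count before exponentiating, so that long and short sequences contribute proportionally. A minimal sketch of that aggregation, assuming you have already collected (mean negative log-likelihood, token count) pairs, e.g. from the `calculate_perplexity` loss values above:

```python
import math

def corpus_perplexity(per_sequence_stats):
    """Token-weighted corpus perplexity.

    per_sequence_stats: list of (mean_nll, num_tokens) pairs, one per
    sequence, where mean_nll is the average cross-entropy loss the model
    reported for that sequence.
    """
    # Recover each sequence's total NLL, pool over the corpus, then
    # exponentiate the per-token average.
    total_nll = sum(nll * n for nll, n in per_sequence_stats)
    total_tokens = sum(n for _, n in per_sequence_stats)
    return math.exp(total_nll / total_tokens)

# Example: one sequence with mean NLL 2.0 over 100 tokens,
# another with mean NLL 3.0 over 50 tokens.
print(corpus_perplexity([(2.0, 100), (3.0, 50)]))  # exp((200 + 150) / 150)
```

Note that simply averaging the per-sequence perplexities would over-weight short sequences; the token-weighted form above matches how perplexity is usually reported over held-out sets.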

Hope this helps.
