Wrong tokens at the end of the sequence #63

Open
pabloacera opened this issue Jan 10, 2024 · 0 comments

pabloacera commented Jan 10, 2024

Hi,

Thanks for making the model available.
I have been playing with the model and noticed that, when running predictions over a DNA sequence, the last token usually does not match the original sequence: the predicted sequence ends with some extra nucleotides.

Am I missing something? Is this the expected behavior? Is there an expected input length that avoids it?


sequences = [
    'TGGAAAGTTGGGGACAAATGTTCTGCCATTTGGTCAGAAGACGGTTGCATTTACCCAGCTACCATTGCTTCAATTGATTTTAAGAGAGAAACCTGTGTTGTGGTTTACACTGGATATGGAAATAGAGAGGAGCAAAATCTGTCCGATCTACTTTCCCCAATCTGTGAAGTAGCTAATAATATAGAACAAAATGCTCAAGAG',
    'ATAATTCCCCCACCACCTCCCATATGTCCAGATTCTCTTGATGATGCTGATGCTTTGGGAAGTATGTTAATTTCATGGTACATGAGTGGCTATCATACTGGCTATTATATG',
    'GACAAGGCCGTGGCGGAGCCTGTCAGCCGCCTGCTGGAGAGCACGCTCAGGAGCAGCCACCTGCCCAGCAGGGTTGGAGCCCTGCACGGCGTCCTCTATGTGCTGGAGTGCGACCTGCTGGACGACACTGCCAAGCAGCTCATCCCGGTCATCAGCGACTATCTCCTCTCCAACCTGAAAGGGATCGCCCA']


# Missing imports added; assumes `tokenizer` and `model` were loaded earlier
# (e.g. via the transformers loading snippet from the repo's README).
import torch
from torch.nn.functional import softmax

for dna in sequences:
    dna = dna[:128]
    
    inputs = tokenizer(dna, return_tensors = 'pt')["input_ids"]
    hidden_states = model(inputs)
    
    logits = hidden_states.logits
    
    # Apply softmax to convert logits to probabilities
    probabilities = softmax(logits, dim=-1)
    
    # Choose the most likely token for each position
    predicted_token_ids = torch.argmax(probabilities, dim=-1)
    
    print('original tokens', inputs)
    print('predicted tokens', predicted_token_ids)
    print()
    
    # Convert these token ids back to nucleotides
    predicted_sequences = [tokenizer.decode(token_ids) for token_ids in predicted_token_ids[:,1:]]
    original = [tokenizer.decode(token_ids) for token_ids in inputs]
    
    print('Original', dna)
    print('Predicted',' '.join(predicted_sequences).replace(' ', ''))
    print()

original tokens tensor([[   1,   11,   45,  316, 1823,   48,  776,   86,   67,  330,  583, 1867,
           95,  105,  173,   60,  162, 3713, 2030,   13,  306,  922,   80,  438,
           70,  609,   50,    2]])
predicted tokens tensor([[ 371,   11,   45,  316, 1823,   48,  776,   86,   67,  330,  583, 1867,
           95,  105,  173,   60,  162, 3713, 2030,   13,  306,  922,   80,  438,
           70,  609,   50,  198]])

Original TGGAAAGTTGGGGACAAATGTTCTGCCATTTGGTCAGAAGACGGTTGCATTTACCCAGCTACCATTGCTTCAATTGATTTTAAGAGAGAAACCTGTGTTGTGGTTTACACTGGATATGGAAATAGAGA
Predicted TGGAAAGTTGGGGACAAATGTTCTGCCATTTGGTCAGAAGACGGTTGCATTTACCCAGCTACCATTGCTTCAATTGATTTTAAGAGAGAAACCTGTGTTGTGGTTTACACTGGATATGGAAATAGAGACAGTT

original tokens tensor([[   1,    5,  109,  906,   34,  209,  902,   16,  410,  149, 2659,  590,
          157,   57,   35, 2368,  224,   35,  246,   30,  105,   22,  236, 2463,
           70,    2]])
predicted tokens tensor([[   5,    5,  109,  906,   34,  209,  902,   16,  410,  149, 2659,  590,
          157,   57,   35, 2368,  224,   35,  246,   30,  105,   22,  236, 2463,
           70,   82]])

Original ATAATTCCCCCACCACCTCCCATATGTCCAGATTCTCTTGATGATGCTGATGCTTTGGGAAGTATGTTAATTTCATGGTACATGAGTGGCTATCATACTGGCTATTATATG
Predicted ATAATTCCCCCACCACCTCCCATATGTCCAGATTCTCTTGATGATGCTGATGCTTTGGGAAGTATGTTAATTTCATGGTACATGAGTGGCTATCATACTGGCTATTATATGTCCA

original tokens tensor([[   1,  225,  136,   30,  708,  192, 1066,  192, 1717,   32,  118,  591,
         2310,   74,   95,  253,  793,   36,  335,   72,  578,   88, 2621,  215,
           93,   74,  438,   93,   12,    6,    2]])
predicted tokens tensor([[  13,  225,  136,   30,  708,  192, 1066,  192, 1717,   32,  118,  591,
         2310,   74,   95,  253,  793,   36,  335,   72,  578,   88, 2621,  215,
           93,   74,  438,   93,   12,    6,   92]])

Original GACAAGGCCGTGGCGGAGCCTGTCAGCCGCCTGCTGGAGAGCACGCTCAGGAGCAGCCACCTGCCCAGCAGGGTTGGAGCCCTGCACGGCGTCCTCTATGTGCTGGAGTGCGACCTGCTGGACGACAC
Predicted GACAAGGCCGTGGCGGAGCCTGTCAGCCGCCTGCTGGAGAGCACGCTCAGGAGCAGCCACCTGCCCAGCAGGGTTGGAGCCCTGCACGGCGTCCTCTATGTGCTGGAGTGCGACCTGCTGGACGACACGGGG
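Looking at the tensors above, the extra nucleotides appear to come from the final special-token slot: each input ends with id 2 (the [SEP]-style token), and the model fills that slot with an ordinary vocabulary token (198, 82, 92), which then decodes to extra bases. Since the decoding loop only drops the first position (`predicted_token_ids[:, 1:]`), the prediction made in the [SEP] slot is kept. A minimal sketch of the idea, with toy ids instead of the real tokenizer (names and the fix itself are my assumption, not confirmed by the maintainers):

```python
# Hypothetical fix: drop BOTH special-token slots before decoding, not just
# the first one. In the tensors above, id 1 marks the start and id 2 the end;
# the model emits an ordinary token (e.g. 198) in the end slot, which then
# decodes to extra nucleotides.

def trim_special_positions(predicted_ids, drop_first=True, drop_last=True):
    """Drop the first/last positions corresponding to special-token slots."""
    start = 1 if drop_first else 0
    end = -1 if drop_last else None
    return predicted_ids[start:end]

# Toy ids mirroring the shape of the first example above:
original = [1, 11, 45, 316, 2]        # start token ... end token
predicted = [371, 11, 45, 316, 198]   # model fills every slot, incl. the end one

# Slicing [1:] keeps the junk token 198; [1:-1] does not.
assert trim_special_positions(predicted) == [11, 45, 316]
```

With the real tensors, that would mean decoding `predicted_token_ids[:, 1:-1]` instead of `predicted_token_ids[:, 1:]`, so the prediction at the [SEP] position is never turned into nucleotides.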
