TJKlein/Essential-Transformer


The Essential Transformer

Understanding the backbone of encoders and decoders in 45 minutes

The Transformer architecture has revolutionized natural language processing and machine translation. This repository provides a minimalist yet comprehensive implementation of the Transformer encoder and decoder, designed to give an intuitive understanding of the core concepts behind this powerful model. The implementations are a didactic resource for enthusiasts, researchers, and learners who want to grasp the fundamental principles. Each implementation takes fewer than 100 lines of code.

To keep things simple, a couple of assumptions are made (see the sketch after this list):

  • positional embeddings are trainable parameters that are added to the token embeddings
  • the embedding dimensionality must be a multiple of the number of attention heads (the joint embedding is reshaped into per-head embeddings before the softmax normalization)
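
A minimal sketch of these two assumptions in plain PyTorch (the names tok_emb, pos_emb, and head_dim below are illustrative, not classes or variables from this repository):

import torch
import torch.nn as nn

emb_dim, num_heads, max_len, vocab_sz, batch_sz = 768, 16, 128, 10000, 2
assert emb_dim % num_heads == 0, "emb_dim must be a multiple of num_heads"
head_dim = emb_dim // num_heads

# Trainable positional embeddings are simply added to the token embeddings
tok_emb = nn.Embedding(vocab_sz, emb_dim)
pos_emb = nn.Parameter(torch.randn(1, max_len, emb_dim))

tokens = torch.randint(0, vocab_sz, (batch_sz, max_len))   # (batch, seq)
x = tok_emb(tokens) + pos_emb                              # (batch, seq, emb_dim)

# The joint embedding is reshaped into per-head embeddings before attention
x_heads = x.view(batch_sz, max_len, num_heads, head_dim).transpose(1, 2)  # (batch, heads, seq, head_dim)
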
  1. Toy example of instantiating a decoder block:
import torch
from decoder import TransformerBlock

# Some toy parameters
num_heads = 16
emb_dim = 768
ffn_dim = 1024
num_layers = 12
max_len = 128
vocab_sz = 10000
batch_sz = 10

# Toy input data corresponding to a batch of embeddings
x = torch.randn((batch_sz, max_len, emb_dim))

# Forward pass through a single decoder block
tb = TransformerBlock(max_len, emb_dim, ffn_dim, num_heads)
out = tb(x)
  2. Toy example of instantiating a transformer decoder:
import torch
from decoder import Transformer

# Some toy parameters
num_heads = 16
emb_dim = 768
ffn_dim = 1024
num_layers = 12
max_len = 128
vocab_sz = 10000
batch_sz = 10

# Toy input data corresponding to random token indices
x = torch.randint(0, vocab_sz, (batch_sz, max_len))

trans = Transformer(num_layers, num_heads, max_len, vocab_sz, emb_dim, ffn_dim)
out = trans(x)
  3. Toy example of instantiating a transformer decoder with multi-query attention:
import torch
from decoder_multi_query_attention import Transformer

# Some toy parameters
num_heads = 16
emb_dim = 768
ffn_dim = 1024
num_layers = 12
max_len = 128
vocab_sz = 10000
batch_sz = 10

# Toy input data corresponding to random token indices
x = torch.randint(0, vocab_sz, (batch_sz, max_len))

trans = Transformer(num_layers, num_heads, max_len, vocab_sz, emb_dim, ffn_dim)
out = trans(x)
  4. Toy example of instantiating a transformer encoder:
import torch
from encoder import Transformer

# Some toy parameters
num_heads = 16
emb_dim = 768
ffn_dim = 1024
num_layers = 12
max_len = 128
vocab_sz = 10000
batch_sz = 10

# Toy input data corresponding to random token indices
x = torch.randint(0, vocab_sz, (batch_sz, max_len))

trans = Transformer(num_layers, vocab_sz, emb_dim, max_len, num_heads, ffn_dim)
out = trans(x)

About

A minimalist 45-minute implementation of the Transformer backbone (encoder and decoder).
