Building an LLM - Episode 2, Llama smol

🔗The Llama Model

Hey folks! Today, I'm super excited to introduce you to the Llama model series. Developed by the brilliant minds at Meta, Llama is a decoder-only LLM built on the transformer architecture. With its simple design, it's the perfect candidate for our first model.

🔗Setting up the Environment

First things first, let's set the stage. And by that, I mean getting all the tools and libraries we need. We'll be making use of several libraries: x-transformers, tokenizers, datasets, transformers, and sentencepiece. A quick pip install, and we're off!

!pip install x-transformers tokenizers datasets transformers sentencepiece

🔗Preparing the Data

Now, onto the fun stuff. The dataset we'll be working with is 'virus_dna_dedup_minihash_0.9_kmer_7'. See our previous post if you're wondering how we created it. After loading the dataset, we insert a space after every seven characters to break the DNA sequences into smaller segments. Think of it as making bite-sized chunks for our tokenizer.

from datasets import load_dataset

virus_ds = load_dataset('Hack90/virus_dna_dedup_minihash_0.9_kmer_7')
virus_ds = virus_ds.map(insert_spaces)  # chunk each sequence into 7-mers separated by spaces
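
The insert_spaces helper isn't shown in this post; here's a minimal sketch of what it could look like, assuming each example stores its raw string under a 'sequence' key (the field name is an assumption about the dataset schema):

def insert_spaces(example, field='sequence', k=7):
    # Hypothetical helper: split the raw DNA string into chunks of up to
    # k characters and rejoin them with spaces, e.g. 'ATGCGTACC' -> 'ATGCGTA CC'.
    seq = example[field]
    example[field] = ' '.join(seq[i:i + k] for i in range(0, len(seq), k))
    return example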

The sequences are then tokenized using SentencePiece, a popular subword tokenization library. We train a small BPE model with a vocabulary of 400 tokens:

import sentencepiece as spm

spm.SentencePieceTrainer.train(input='seqs', model_prefix='dna', vocab_size=400,
                               character_coverage=1.0, model_type='bpe',
                               max_sentence_length=150_000_000)
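
Training writes a dna.model / dna.vocab pair to disk. As a quick sanity check, something like the following should round-trip a sequence (the sample string below is just an illustration):

import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file='dna.model')

# Encode a space-chunked sequence into token ids, then decode it back
ids = sp.encode('ATGCGTA CCGTTAG', out_type=int)
print(ids)
print(sp.decode(ids))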

🔗The Models

[Figure: example Llama architecture]

Llama models are a family of decoder-only transformer models. Although not fully accurate, you can conceptualise the model as consisting of three main layers/components:

  1. Embedding Layer: This layer converts the input tokens into vectors of a fixed size.
  2. Llama Blocks: This part consists of the main building blocks of the model. Each block consists of:
    • RMSNorm: A normalisation sublayer.
    • RoPEMaskedMultiheadAttention: Rotary positional encoding combined with multi-head attention.
    • Feedforward: A feedforward neural network with a SwiGLU (Swish-Gated Linear Unit) activation.
  3. Final Feedforward Network: This layer takes the output from the Llama blocks and produces our final logits.

Here's a single block in code (RMSNorm, RoPEMaskedMultiheadAttention and SwiGLU are custom modules that aren't reproduced in this post):

import torch
from torch import nn
import torch.nn.functional as F
from collections import OrderedDict

class LlamaBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # RMS normalisation over (context_window, d_model) activations
        self.rms = RMSNorm((config['context_window'], config['d_model']))
        # Rotary positional encoding + masked multi-head attention
        self.attention = RoPEMaskedMultiheadAttention(config)
        # Feedforward sublayer with a SwiGLU activation
        self.feedforward = nn.Sequential(
            nn.Linear(config['d_model'], config['d_model']),
            SwiGLU(config['d_model']),
        )

    def forward(self, x):
        x = self.rms(x)              # RMS normalisation
        x = x + self.attention(x)    # add the attention output back (residual)

        x = self.rms(x)              # RMS normalisation
        x = x + self.feedforward(x)  # add the feedforward output back (residual)
        return x

class Llama(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Token embedding table
        self.embeddings = nn.Embedding(config['vocab_size'], config['d_model'])
        # Stack of n_layers LlamaBlocks
        self.llama_blocks = nn.Sequential(
            OrderedDict([(f"llama_{i}", LlamaBlock(config)) for i in range(config['n_layers'])])
        )
        # Final feedforward network that projects back to vocabulary logits
        self.ffn = nn.Sequential(
            nn.Linear(config['d_model'], config['d_model']),
            SwiGLU(config['d_model']),
            nn.Linear(config['d_model'], config['vocab_size']),
        )
        print("model params:", sum(m.numel() for m in self.parameters()))

    def forward(self, idx, targets=None):
        x = self.embeddings(idx)
        x = self.llama_blocks(x)
        logits = self.ffn(x)
        if targets is None:
            return logits
        # When targets are provided, also return the next-token cross-entropy loss
        loss = F.cross_entropy(logits.view(-1, self.config['vocab_size']), targets.view(-1))
        return logits, loss
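
The code above expects a config dict with a handful of keys. The exact values used in the notebook aren't shown, so the numbers below are illustrative placeholders; only vocab_size, d_model, n_layers and context_window are keys the code actually reads (n_heads is an assumption about what the attention module needs, and batch_size / epochs are only used by the training sketch further down):

MASTER_CONFIG = {
    'vocab_size': 400,       # matches the SentencePiece vocabulary size
    'd_model': 128,          # embedding / hidden width (placeholder)
    'n_layers': 4,           # number of LlamaBlocks (placeholder)
    'context_window': 512,   # sequence length the model sees (placeholder)
    'n_heads': 8,            # assumed to be read by the attention module
    'batch_size': 32,        # placeholder, used by the training sketch below
    'epochs': 1000,          # placeholder, used by the training sketch below
}

llama = Llama(MASTER_CONFIG)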

🔗Training

The model is trained using the Adam optimizer. During training, the model is evaluated at regular intervals to monitor the validation loss.

optimizer = torch.optim.Adam(llama.parameters())
train(llama, optimizer, config=MASTER_CONFIG)
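
The train helper isn't reproduced in this post, so here's only a rough sketch of what such a loop could look like, assuming a get_batches function that samples (input, target) windows of shape (batch_size, context_window) from the tokenized dataset:

import torch

def train(model, optimizer, config, eval_interval=100):
    for step in range(config['epochs']):
        # get_batches is assumed to return random (idx, targets) windows
        xb, yb = get_batches(split='train', config=config)
        logits, loss = model(xb, yb)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Periodically check the validation loss to watch for overfitting
        if step % eval_interval == 0:
            with torch.no_grad():
                xv, yv = get_batches(split='val', config=config)
                _, val_loss = model(xv, yv)
            print(f"step {step}: train {loss.item():.3f}, val {val_loss.item():.3f}")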

🔗Generation

After training, the Llama model can be used to generate new DNA sequences. The generate function takes in a model and produces a sequence of tokens, which can be decoded back into a DNA sequence.

print(generate(llama, MASTER_CONFIG, 500)[0])
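
generate itself isn't shown in the post either; a bare-bones autoregressive sampler with the same call signature might look like the sketch below, where sp is the SentencePieceProcessor loaded earlier and the start token (id 0) is an arbitrary choice:

import torch
import torch.nn.functional as F

def generate(model, config, max_new_tokens=500, n_sequences=1):
    # Start every sequence from token id 0 (an arbitrary assumption)
    idx = torch.zeros(n_sequences, 1).long()
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Only feed the most recent context_window tokens to the model
            logits = model(idx[:, -config['context_window']:])
            probs = F.softmax(logits[:, -1, :], dim=-1)  # next-token distribution
            next_token = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_token], dim=-1)
    # Decode each generated id sequence back into a DNA string
    return [sp.decode(row.tolist()) for row in idx]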

Or, instead of doing the previous steps, we can instantiate a model using the x-transformers library from lucidrains.

from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 400,              # vocabulary size, matching our tokenizer
    max_seq_len = 512,
    attn_layers = Decoder(
        dim = 32,
        depth = 32,
        heads = 8,
        use_simple_rmsnorm = True,
        alibi_pos_bias = True,     # ALiBi positional bias instead of rotary embeddings
        alibi_num_heads = 4
    )
)
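
To train and sample from this version, x-transformers ships an AutoregressiveWrapper that takes care of the shifted-target cross-entropy loss and autoregressive decoding; roughly (the random tokens below are just stand-ins for a real batch of tokenized DNA):

import torch
from x_transformers import AutoregressiveWrapper

ar_model = AutoregressiveWrapper(model)

# One training step on a dummy batch of token ids
tokens = torch.randint(0, 400, (1, 512))
loss = ar_model(tokens)
loss.backward()

# Sample 500 new tokens from a one-token prompt
prompt = torch.randint(0, 400, (1, 1))
generated = ar_model.generate(prompt, 500)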

As you can see, creating a model is relatively straightforward. The real magic lies in the quality and quantity of the training data. So far, we've amassed approximately 35 million tokens of deduplicated viral DNA sequences. However, it's worth noting that the most advanced models are trained on datasets exceeding the 1 trillion token mark. With that in mind, let's refocus our efforts on data collection. In our next post, we plan to embark on the ambitious task of gathering every DNA sequence available in NCBI GenBank and meticulously de-duplicating that dataset.

Tags: #LLMs #biology #modelling #generative_learning