Hassan

Building an LLM - Episode 1, Getting The Data

5 minutes (1170 words)

You’ve probably heard about AI and its transformative potential countless times. Every company in the country is either pitching or planning an ‘AI’ deployment, and its role in the future of work seems all but guaranteed. But what is it all about? How does a computer learn to answer questions, create illustrations, or suggest books based on your tastes? Over the next few posts, starting with simple code examples, I aim to answer those questions.

While there are myriad model types and architectures, the ones you might be most familiar with are chatbots like ChatGPT or Llama 2. These transformer-based large language models (LLMs) go through a three-step training procedure:

  1. Pre-training
  2. Supervised fine-tuning (SFT)
  3. Alignment

The pre-training phase is where a model absorbs knowledge from an extensive dataset, typically text. This phase can be likened to populating a spreadsheet with information. The next two steps, SFT and alignment, shape how the model responds: they are like handing someone a script for how to use the information contained in that spreadsheet.
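
To make the distinction concrete, here is a purely illustrative sketch of what a single training example might look like at each phase (these samples are invented for this post, not drawn from any real training set):

# Pre-training: raw text; the model simply learns to predict the next token.
pretraining_sample = "ATGACCGGTACCTTAGGC..."

# Supervised fine-tuning: prompt/response pairs that teach a response format.
sft_sample = {"prompt": "Generate a short viral genome fragment.",
              "response": "ATGACCGGTACCTTAGGC..."}

# Alignment: preference data ranking one candidate response over another.
alignment_sample = {"prompt": "Generate a short viral genome fragment.",
                    "chosen": "ATGACCGGTACCTTAGGC...",
                    "rejected": "NNNNNNNNNNNNNNNN..."}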

To illustrate, let’s build a model that can generate the genome of a virus in its entirety. Viruses serve as an ideal learning tool because their genomes are compact enough to fit within the context window of a standard LLM (around 3,000 tokens). Don’t worry, though; we won’t be training the model to anywhere near the precision needed to actually pose a threat, as our model will be relatively small. Nonetheless, this exercise should provide a clearer understanding of the model-building process.

Here’s a snapshot of potential data sources:

[Image: Potential data sources]

Now let’s pick a source and begin coding! (Note: some variables and steps have been removed for brevity; they are available in the full notebooks linked below.)

import ftplib
from Bio import SeqIO
import glob
import pandas as pd
from datasets import load_dataset
import sys

# Connect to the BV-BRC FTP server with an anonymous login.
ftp_server = 'ftp.bvbrc.org'
ftp_session = ftplib.FTP(ftp_server)
ftp_session.login()
ftp_session.retrlines('LIST')   # print the top-level directory listing
ftp_session.cwd('viruses')      # move into the viruses directory

# The list of files being retrieved is available in the full notebook.
for item in file_list:
  ftp_get_file(ftp_session, item)  # Ditto for the ftp_get_file function.
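
The download helper itself is omitted above. As a rough sketch of what ftp_get_file might do (the real function lives in the notebook, so treat the body here as an assumption), it could simply stream each remote file to disk:

def ftp_get_file(session, filename):
    # Hypothetical helper: fetch the remote file in binary mode and
    # write it to a local file of the same name.
    with open(filename, 'wb') as handle:
        session.retrbinary(f'RETR {filename}', handle.write)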

Next, let’s read the files, process the data, and save the results as Parquet files, a columnar format well suited to large datasets.

# Parse each FASTA file and collect per-record metadata.
for file in glob.glob("*.fna"):
    f_name = str(file).replace('.fna', '')
    viruses = []
    try:
        for record in SeqIO.parse(file, "fasta"):
            record_id = str(record.id)
            record_seq = str(record.seq)
            record_name = str(record.name)
            record_description = str(record.description)
            record_no_of_features = len(record.features)  # empty for plain FASTA records
            record_seq_length = len(record_seq)
            virus = [record_id, record_seq, record_name, record_description,
                     record_no_of_features, record_seq_length]
            viruses.append(virus)
    except Exception as err:
        print(f'{f_name} has an error: {err}')
    df = pd.DataFrame(viruses, columns=['id', 'sequence', 'name',
                                        'description', 'features',
                                        'seq_length'])
    df.to_parquet(f'{f_name}.parquet')
    print(f"Function call finished for {f_name}")

Once that’s done, it’s time to load the data into a Hugging Face dataset.

dataset = load_dataset("parquet", data_files="/content/dna/*.parquet")

[Image: The data]

Now that we’ve loaded our data, it’s time for some preliminary preprocessing and quality checks. This entails defining a minimum and maximum sequence length, assessing sequence quality, addressing missing values, and deduplicating based on sequences. While we could delve deeper into these processes, we’ll reserve that discussion for our subsequent post.

# The quality checks below use pandas, so pull the train split into a DataFrame first.
data = dataset['train'].to_pandas()
data = data.drop_duplicates(subset=['sequence']).copy()
data = data[data['seq_length'] < 50_000]
data = data[data['seq_length'] > 5_000]
# Count ambiguous/missing bases (N) and drop sequences with more than 1% missingness.
data['missing_seq_count'] = data.sequence.str.upper().str.count('N')
data['missingness'] = data['missing_seq_count'] / data['seq_length']
data = data[data.missingness < 0.01].copy()
# Fill any remaining ambiguous bases; a sketch of this helper is given below.
data['seq_filled'] = data['sequence'].apply(replace_non_nucleotide_with_random)
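
The replace_non_nucleotide_with_random helper is one of the pieces left to the full notebook. A plausible sketch, assuming it simply swaps any non-ACGT character for a random base (the implementation here is a guess, not the notebook’s code):

import random

def replace_non_nucleotide_with_random(sequence):
    # Hypothetical reconstruction: replace anything that isn't A, C, G, or T
    # with a random nucleotide so the k-mer hashing below sees only valid bases.
    return ''.join(base if base in 'ACGT' else random.choice('ACGT')
                   for base in sequence.upper())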

# MinHash sketching (the MinHash class is assumed here to come from the sourmash library).
from sourmash import MinHash

sequences = data['seq_filled'].to_list()
signatures = []
# Create a MinHash signature for each sequence
for seq in sequences:
    minhash = MinHash(n=1000, ksize=7)
    minhash.add_sequence(seq)
    signatures.append(minhash)

# Greedy near-duplicate removal: keep a sequence only if its estimated Jaccard
# similarity to every previously kept sequence is at most 0.9.
unique_signatures = []
unique_sequences = []
for i, sig in enumerate(signatures):
    is_similar = any(sig.jaccard(uni_sig) > 0.9 for uni_sig in unique_signatures)
    if not is_similar:
        unique_signatures.append(sig)
        unique_sequences.append(sequences[i])

Finally, our last step is to left-join the unique sequences back onto the full table and save the resulting dataset. That sets the stage for the next, and arguably most crucial, step: pre-training!
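
For completeness, here is a minimal sketch of that join and save, assuming pandas and the column names used above; the output filename is an invention of this post rather than the notebook’s:

# Keep only rows whose filled sequence survived deduplication,
# then persist the cleaned dataset for the pre-training step.
unique_df = pd.DataFrame({'seq_filled': unique_sequences})
final = unique_df.merge(data, on='seq_filled', how='left')
final.to_parquet('clean_viral_genomes.parquet')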

🔗References

Tags: #LLMs #biology #dataset_curation #generative_learning