Fine-Tuning Large Datasets (Beta)
This guide walks you through fine-tuning LLMs on very large datasets in Predibase.
What is Considered a Large Dataset?
In the context of LLM fine-tuning, a large dataset is typically considered to be greater than 1 GB in size.
Datasets of this size need special handling and efficient processing to fine-tune smoothly. Predibase requires these datasets to be tokenized in advance and saved in the partitioned Arrow format used by the Hugging Face Datasets library.
Steps to Prepare and Upload Your Dataset
Follow these steps to get your large dataset ready for fine-tuning in Predibase:
1. Load Your Dataset
You can load your dataset either in-memory or out-of-memory, depending on its size and your available resources. Here's an example using the Hugging Face datasets library:
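```python
from datasets import load_dataset

# Regular load: split="train" returns a Dataset backed by a local Arrow cache
dataset = load_dataset("csv", data_files="your_dataset.csv", split="train")

# For out-of-memory processing, stream the rows instead (returns an IterableDataset).
# Note: an IterableDataset has no save_to_disk; convert it before step 6 (e.g. with Dataset.from_generator)
dataset = load_dataset("csv", data_files="your_dataset.csv", split="train", streaming=True)
```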
2. Load Your Tokenizer
Use the transformers library to load the tokenizer for the LLM you want to fine-tune. Some models, such as Llama-3 and Mistral, are gated: you need a Hugging Face account and must request access before you can download them.
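```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
# Some tokenizers define no pad token; fall back to EOS.
# Check against None explicitly, because a valid pad token ID can be 0.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
```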
3. Batch Tokenize Your Data
You will need to tokenize the data ahead of time.
- If you're doing instruction tuning, tokenize the `prompt` and `completion` columns independently.
- If you're doing completions-style training, you only need to tokenize your `text` column.

The code below covers the instruction-tuning case.
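```python
def tokenize_data(examples):
    # batched=True: each column value is a list of strings
    prompt_tokens = tokenizer(examples["prompt"], truncation=True)["input_ids"]
    # No special tokens for the completion, so a second BOS token does not land mid-sequence
    completion_tokens = tokenizer(examples["completion"], truncation=True, add_special_tokens=False)["input_ids"]
    return {"prompt": prompt_tokens, "completion": completion_tokens}

tokenized_dataset = dataset.map(tokenize_data, batched=True)
```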
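The code above only handles instruction tuning. A minimal sketch of the completions-style path, assuming your raw text lives in a column named `text` (the function name `tokenize_text` is hypothetical), tokenizes the text, appends EOS, and reuses the input IDs as labels in a single batched pass:

```python
def tokenize_text(examples, tokenizer):
    """Batched tokenization for completions-style training (sketch).

    Assumes a column named "text". Every token is a training target,
    so the labels are copies of the input IDs.
    """
    token_ids = tokenizer(examples["text"], truncation=True)["input_ids"]
    # Append EOS so the model learns where a document ends
    input_ids = [ids + [tokenizer.eos_token_id] for ids in token_ids]
    # Independent copies, not aliases of the input_ids lists
    labels = [list(ids) for ids in input_ids]
    return {"input_ids": input_ids, "labels": labels}
```

It could be applied with `dataset.map(tokenize_text, batched=True, fn_kwargs={"tokenizer": tokenizer}, remove_columns=dataset.column_names)`. Because EOS is appended after truncation, a sequence truncated to the model's maximum length ends up one token longer.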
4. Create input_ids and labels
Next, concatenate the tokenized prompt and completion and append an EOS token at the end. How the labels are built depends on your training approach:
- For instruction tuning, concatenate your prompt tokens with your completion tokens, and set the label of every prompt position to -100 so the loss ignores it.
- For completions-style training, you can simply reuse your input tokens as your labels.
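Truncating the prompt and the completion independently in step 3 does not bound the length of their concatenation. The labeling rule can be checked on toy token IDs before mapping over the full dataset. The sketch below uses a hypothetical `max_length` cap and hypothetical helper names; the IDs are made up, not produced by a real tokenizer:

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss by default


def build_example(prompt_ids, completion_ids, eos_token_id, max_length=None):
    """Build input_ids/labels for one instruction-tuning example (sketch).

    max_length is an assumed context-length cap; both lists are cut at the
    same position so each label stays aligned with its input token.
    """
    input_ids = prompt_ids + completion_ids + [eos_token_id]
    labels = [IGNORE_INDEX] * len(prompt_ids) + completion_ids + [eos_token_id]
    if max_length is not None:
        input_ids = input_ids[:max_length]
        labels = labels[:max_length]
    return {"input_ids": input_ids, "labels": labels}


def has_trainable_tokens(labels):
    """True if at least one position contributes to the loss."""
    return any(label != IGNORE_INDEX for label in labels)


# Toy IDs: prompt [5, 6], completion [7, 8], EOS 2
example = build_example([5, 6], [7, 8], eos_token_id=2)
print(example["input_ids"])  # [5, 6, 7, 8, 2]
print(example["labels"])     # [-100, -100, 7, 8, 2]

# A tight cap can cut the completion off entirely
short = build_example([5, 6], [7, 8], eos_token_id=2, max_length=2)
print(has_trainable_tokens(short["labels"]))  # False
```

Rows whose labels end up fully masked carry no training signal. With the Datasets library they could be dropped with `prepared_dataset.filter(lambda ex: has_trainable_tokens(ex["labels"]))`.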
```python
def prepare_train_features(example):
    # Called per row (no batched=True), so each value is a single list of token IDs
    prompt_ids = example["prompt"]
    completion_ids = example["completion"]
    input_ids = prompt_ids + completion_ids + [tokenizer.eos_token_id]
    # For instruction tuning: -100 masks the prompt so only the completion is learned
    labels = [-100] * len(prompt_ids) + completion_ids + [tokenizer.eos_token_id]
    # For completions style training
    # labels = input_ids.copy()
    # For multi-turn conversation, mask every token of each user turn
    # (user_token_mask is a placeholder: one boolean per token, True inside user turns)
    # labels = [-100 if is_user else token for token, is_user in zip(input_ids, user_token_mask)]
    return {"input_ids": input_ids, "labels": labels}

prepared_dataset = tokenized_dataset.map(prepare_train_features)
```
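The multi-turn comment in `prepare_train_features` leaves open how user tokens are identified. One approach is to mask whole turns by role. It assumes each conversation has already been split into `(role, token_ids)` turns, for example by tokenizing each turn with the model's chat template. The role names and the function below are hypothetical, not a Predibase convention:

```python
IGNORE_INDEX = -100  # label value ignored by the loss


def build_multiturn_example(turns, eos_token_id):
    """Concatenate pre-tokenized turns, learning only from assistant turns (sketch).

    turns: list of (role, token_ids) pairs, e.g. [("user", [...]), ("assistant", [...])].
    """
    input_ids, labels = [], []
    for role, token_ids in turns:
        input_ids.extend(token_ids)
        if role == "assistant":
            labels.extend(token_ids)  # assistant replies are training targets
        else:
            labels.extend([IGNORE_INDEX] * len(token_ids))  # mask user/system turns
    input_ids.append(eos_token_id)
    labels.append(eos_token_id)
    return {"input_ids": input_ids, "labels": labels}


conversation = [("user", [10, 11]), ("assistant", [20]), ("user", [12]), ("assistant", [21, 22])]
result = build_multiturn_example(conversation, eos_token_id=2)
print(result["input_ids"])  # [10, 11, 20, 12, 21, 22, 2]
print(result["labels"])     # [-100, -100, 20, -100, 21, 22, 2]
```

Masking by turn rather than by individual token value keeps the decision tied to who said what, so a user message that happens to contain the same token IDs as an assistant reply is still fully masked.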
5. Create a Split Column (Optional)
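Add a `split` column that assigns each row to training or evaluation (here roughly 95% train and 5% evaluation), then drop every column except `input_ids`, `labels`, and `split`: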
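```python
import random

random.seed(42)  # make the split reproducible across runs

def add_split_column(example):
    # ~95% of rows go to "train", ~5% to "evaluation"
    example["split"] = "train" if random.random() < 0.95 else "evaluation"
    return example

final_dataset = prepared_dataset.map(add_split_column)
# Keep only input_ids, labels and split
final_dataset = final_dataset.remove_columns(
    [col for col in final_dataset.column_names if col not in ["input_ids", "labels", "split"]]
)
```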
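Seeding `random` makes the split reproducible only as long as rows are processed in the same order by a single process. An alternative, sketched here under the assumption that each row has a stable key (for example its index), is to derive the split from a hash of that key. The function name `assign_split` is hypothetical:

```python
import hashlib


def assign_split(key, eval_fraction=0.05):
    """Deterministically assign a row to "train" or "evaluation" (sketch).

    The same key always maps to the same split, independent of run or process.
    """
    digest = hashlib.sha256(key.encode("utf-8")).digest()
    # First 4 bytes as an unsigned integer, scaled into [0, 1)
    bucket = int.from_bytes(digest[:4], "big") / 2**32
    return "evaluation" if bucket < eval_fraction else "train"


print(assign_split("row-0") == assign_split("row-0"))  # True
```

With the Datasets library it could be applied as `prepared_dataset.map(lambda ex, idx: {"split": assign_split(str(idx))}, with_indices=True)`. Using the index as the key assumes row order does not change between runs.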
6. Save the Dataset
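Save your prepared dataset to disk, or directly to S3, with the Hugging Face Datasets library. It must be saved in the Hugging Face Datasets format as a multi-partition Arrow dataset. `save_to_disk` writes this layout and splits large datasets into several `.arrow` shards.
```python
# Save locally; max_shard_size caps the size of each .arrow shard
final_dataset.save_to_disk("path/to/local/directory", max_shard_size="500MB")

# Or save directly to S3 (requires the s3fs package); credentials go in storage_options
final_dataset.save_to_disk(
    "s3://your-bucket-name/path/to/dataset",
    storage_options={"key": "<your_aws_access_key_id>", "secret": "<your_aws_secret_access_key>"},
)
```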
If you save locally, upload the entire dataset folder to S3, including every `.arrow` shard and the JSON metadata files written alongside them.
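When uploading a locally saved dataset, the folder layout has to be preserved. The following helper, hypothetically named `s3_upload_plan`, lists the file-to-key mapping so that every shard and metadata file keeps its relative path under the chosen S3 prefix. It only computes the plan and does not contact S3:

```python
from pathlib import Path, PurePosixPath


def s3_upload_plan(local_dir, s3_prefix):
    """Return (local_file, s3_key) pairs mirroring local_dir under s3_prefix (sketch)."""
    root = Path(local_dir)
    plan = []
    for path in sorted(root.rglob("*")):
        if path.is_file():
            # Keep the path relative to the dataset root, with forward slashes for S3
            relative = path.relative_to(root).as_posix()
            plan.append((str(path), str(PurePosixPath(s3_prefix) / relative)))
    return plan
```

You could pass each pair to boto3's `upload_file(Filename, Bucket, Key)`. Alternatively, `aws s3 cp path/to/local/directory s3://your-bucket-name/path/to/dataset --recursive` copies the whole folder in one command.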
7. Upload to Predibase
Finally, use the Predibase S3 connector to upload your dataset through the product interface.