Example: Fine-tuning & Serving
Learn how to fine-tune and serve a large language model (LLM) for your application. Predibase offers you the ability to seamlessly put open-source LLMs in production without the training headaches or GPU setup.
In this guide, we will fine-tune and serve a text summarizer using mistral-7b, an open-source LLM from Mistral. We will use the same dataset that was used for the News Headline Generation task in LoraLand.
Supported Models
Predibase supports many popular OSS models for fine-tuning including:
- llama-3-1-8b-instruct
- mistral-7b-instruct-v0-2
- qwen2-7b
To see all models available to fine-tune, check out the full list of available models.
Prepare Data
Predibase supports a variety of different data connectors including File Upload, S3, Snowflake, Databricks, and more. You can also upload your data in a few different file formats. We usually recommend CSV or JSONL. See more details about dataset preparation for fine-tuning.
Instruction Tuning
Your dataset should have the following structure:
- prompt: The fully materialized input to your model
- completion: The output you expect from your model
In the case of JSONL, it should look something like:
{"prompt": ..., "completion": ...}
{"prompt": "Please summarize the following article ...", "completion": "Madonna kicks off Celebration World Tour in London"}
{"prompt": "Please summarize the following article ...", "completion": "Facebook Releases First Transparency Report on what Americans see on the platform"}
{"prompt": ..., "completion": ...}
Train Adapter
Set Up Environment
Set up a venv and install the Python SDK (only needed when running locally; not required for Google Colab or similar):
python3.9 -m venv .venv
source .venv/bin/activate
pip install predibase
You can use the Web UI or the Python SDK to connect your data and start a fine-tuning job. Currently the SDK only supports file uploads, so if you wish to use a data connection (e.g. S3), you'll need to use the UI.
Initialize Predibase Client
Then from a Python script or notebook (using the venv created above):
from predibase import Predibase, FinetuningConfig, DeploymentConfig
pb = Predibase(api_token="<PREDIBASE API TOKEN>")
You can generate an API token on the homepage or find an existing key under Settings > My Profile.
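To keep the token out of scripts and notebooks, one option is to read it from an environment variable. This is a minimal sketch: `load_api_token` is a hypothetical helper, and the variable name `PREDIBASE_API_TOKEN` is an assumption chosen for illustration, so use whatever name fits your setup.

```python
import os


def load_api_token(env_var="PREDIBASE_API_TOKEN"):
    """Return the API token stored in an environment variable, or raise if unset."""
    # Assumption: the token was exported under this (illustrative) variable name.
    token = os.environ.get(env_var, "").strip()
    if not token:
        raise RuntimeError(
            f"Set the {env_var} environment variable to your Predibase API token"
        )
    return token


# Usage (requires the predibase package installed in the setup step):
# from predibase import Predibase
# pb = Predibase(api_token=load_api_token())
```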
Connect Data to Predibase
We will use the tldr_news dataset, which is available at this Google Drive link. This dataset has been pre-formatted with the instruction template for mistral-7b. You can find the original dataset, with a numerical split column and no prompt template, on HuggingFace.
dataset = pb.datasets.from_file("</path/to/tldr_dataset.csv>", name="tldr_dataset")
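If you start from the raw dataset instead, you would need to apply the prompt template yourself before uploading. The sketch below shows one way to do that with the standard library. It assumes the template matches the `[INST]` prompt used in the inference examples later in this guide, and `build_prompt`, `to_rows`, and `write_csv` are hypothetical helpers rather than SDK functions.

```python
import csv

# Assumption: same instruction wording as the inference prompt used in this guide.
PROMPT_TEMPLATE = (
    "<s>[INST] The following passage is content from a news report. "
    "Please summarize this passage in one sentence or less. \n "
    "Passage: {passage} \n Summary: [/INST] "
)


def build_prompt(passage):
    """Wrap a raw passage in the instruction template."""
    return PROMPT_TEMPLATE.format(passage=passage.strip())


def to_rows(pairs):
    """Turn (passage, headline) pairs into prompt/completion rows."""
    return [
        {"prompt": build_prompt(passage), "completion": headline.strip()}
        for passage, headline in pairs
    ]


def write_csv(rows, fileobj):
    """Write rows to an open text file as CSV with prompt and completion columns."""
    writer = csv.DictWriter(fileobj, fieldnames=["prompt", "completion"])
    writer.writeheader()
    writer.writerows(rows)
```

When writing to disk, open the file with `newline=""` as the `csv` module documentation recommends, for example `open("tldr_dataset.csv", "w", newline="", encoding="utf-8")`.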
Kick Off Training
We can start a fine-tuning job with the recommended defaults and see live metric updates. See the list of available models. (Note: you must use the short names provided in the list.) You may also kick off a non-blocking training job.
# Create an adapter repository
repo = pb.repos.create(name="news-summarizer-model", description="TLDR News Summarizer Experiments", exists_ok=True)
# Start a fine-tuning job, blocks until training is finished
adapter = pb.adapters.create(
    config=FinetuningConfig(
        base_model="mistral-7b"
    ),
    dataset=dataset,  # Also accepts the dataset name as a string
    repo=repo,
    description="initial model with defaults"
)
Customize Hyperparameters (Optional)
Predibase allows you to configure hyperparameters for your training job. See the full list of available parameters.
# Create an adapter repository
repo = pb.repos.create(name="news-summarizer-model", description="TLDR News Summarizer Experiments", exists_ok=True)
# Start a fine-tuning job with custom parameters, blocks until training is finished
adapter = pb.adapters.create(
    config=FinetuningConfig(
        base_model="mistral-7b",
        adapter="turbo_lora",  # default: "lora"; Turbo LoRA is a proprietary fine-tuning method which greatly improves inference throughput for longer output token tasks.
        epochs=1,  # default: 3
        rank=8,  # default: 16
        learning_rate=0.0001,  # default: 0.0002
        target_modules=["q_proj", "v_proj", "k_proj"],  # default: None (infers [q_proj, v_proj] for mistral-7b)
        apply_chat_template=False,  # default: False
    ),
    dataset=dataset,
    repo=repo,
    description="changing epochs, rank, and learning rate"
)
See more about apply_chat_template here: https://docs.predibase.com/user-guide/fine-tuning/chat_templates
Each adapter version has its own page in the Adapters UI.
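Because each run becomes its own adapter version, it can be convenient to enumerate the hyperparameter combinations you want to try before launching anything. The sketch below is illustrative only: `config_grid` is a hypothetical helper, and the parameter names (`base_model`, `rank`, `learning_rate`) are the ones used in the example above.

```python
from itertools import product


def config_grid(base_model, **options):
    """Return one keyword-argument dict per combination of the given option lists."""
    names = sorted(options)
    grid = []
    for values in product(*(options[name] for name in names)):
        params = {"base_model": base_model}
        params.update(zip(names, values))
        grid.append(params)
    return grid


grid = config_grid("mistral-7b", rank=[8, 16], learning_rate=[0.0001, 0.0002])
for params in grid:
    print(params)
    # With the SDK installed, each dict could be expanded into a config:
    # FinetuningConfig(**params)
```

Keeping the grid as plain dicts lets you record each combination in the adapter's `description`, so versions are easy to tell apart in the Adapters UI.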
Monitor Progress
You can always return to your fine-tuning job in the SDK.
# Get adapter, blocking call if training is still in progress
adapter = pb.adapters.get("news-summarizer-model/1")
Predibase also supports progress tracking through Weights & Biases or Comet. In order to use this feature, you must have a Weights & Biases / Comet account and API key, which can be added in the Predibase UI under Settings > My Profile.
Use Your Adapter
Start by prompting your adapter using a shared endpoint and then once you're happy with your adapter's performance, create a private serverless deployment for production use.
Shared Endpoints (Free)
Shared endpoints are a public resource we offer for getting started, experimentation, and fast iteration. If your base model is one that is hosted as a shared endpoint, you can use your fine-tuned model instantly by utilizing LoRAX:
input_prompt="<s>[INST] The following passage is content from a news report. Please summarize this passage in one sentence or less. \n Passage: Memray is a memory profiler for Python. It can help developers discover the cause of high memory usage, find memory leaks, and find hotspots in code that cause a lot of allocations. Memray can be used both as a command-line tool or as a library. \n Summary: [/INST] "
lorax_client = pb.deployments.client("mistral-7b")
print(lorax_client.generate(input_prompt, adapter_id="news-summarizer-model/1", max_new_tokens=100).generated_text)
The call to pb.deployments.client returns a LoRAX client that we can use for prompting. The call to generate passes in the adapter repo name and version to prompt our fine-tuned model.
For this query, the expected response should look something like Memray (GitHub Repo).
Note that LLMs are non-deterministic so your response may look slightly different.
We can compare the fine-tuned model to the base model by calling generate without adapter_id:
print(lorax_client.generate(input_prompt, max_new_tokens=100).generated_text)
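To make the comparison repeatable across several prompts, the two calls can be wrapped in a small function. This is a sketch, not an SDK feature: `compare_outputs` is a hypothetical helper that only relies on the `generate(...)` call and `generated_text` attribute shown above.

```python
def compare_outputs(client, prompt, adapter_id, max_new_tokens=100):
    """Generate with and without the adapter and return both texts."""
    # Fine-tuned model: same call as above, with adapter_id set.
    tuned = client.generate(
        prompt, adapter_id=adapter_id, max_new_tokens=max_new_tokens
    )
    # Base model: identical call without adapter_id.
    base = client.generate(prompt, max_new_tokens=max_new_tokens)
    return {"fine_tuned": tuned.generated_text, "base": base.generated_text}


# Usage with the client created earlier:
# result = compare_outputs(lorax_client, input_prompt, "news-summarizer-model/1")
# print(result["fine_tuned"])
# print(result["base"])
```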
Please note that shared endpoints are subject to rate limits to ensure fair usage among all customers.
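If you send many requests in a loop, you may want to retry when a request is rejected. The following is a minimal sketch of exponential backoff. It assumes a rate-limited request surfaces as an exception from `generate`; the exact error type is not documented here, so the sketch catches `Exception` broadly, and `generate_with_retry` is a hypothetical helper.

```python
import time


def generate_with_retry(client, prompt, retries=3, base_delay=1.0,
                        sleep=time.sleep, **kwargs):
    """Call client.generate, retrying with exponential backoff when it raises."""
    for attempt in range(retries + 1):
        try:
            return client.generate(prompt, **kwargs)
        except Exception:
            # Assumption: a rate-limited request raises an exception.
            # Narrow this to the specific error type once you know it.
            if attempt == retries:
                raise
            # Wait 1x, 2x, 4x, ... the base delay before the next attempt.
            sleep(base_delay * 2 ** attempt)


# Usage with the client created earlier:
# response = generate_with_retry(
#     lorax_client, input_prompt,
#     adapter_id="news-summarizer-model/1", max_new_tokens=100,
# )
# print(response.generated_text)
```

The `sleep` parameter exists only so the delay can be replaced in tests; in normal use the default `time.sleep` applies.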
Private Serverless Deployments ($/GPU-hour)
Once you're ready for production, deploy a private instance of the base model for greater reliability and control, with no rate limiting. LoRAX enables you to serve an unlimited number of adapters on a single base model deployment.
Predibase officially supports serving these models. Note that by default private serverless deployments spin down after 12 hours of no activity. (To change this, set min_replicas to 1.) For the base_model, you'll need the model name, which can be found here for the models we officially support.
Private serverless deployments are available to Developer and Enterprise tier users. Free tier users are upgraded to the Developer tier automatically when they add a credit card.
# Deploy
pb.deployments.create(
    name="my-mistral-7b",
    config=DeploymentConfig(
        base_model="mistral-7b",
        # cooldown_time=3600,  # Value in seconds, defaults to 3600 (1hr)
        min_replicas=0,  # Auto-scales to 0 replicas when not in use
        max_replicas=1
    ),
    # description="",  # Optional
)
# Prompt
input_prompt="<s>[INST] The following passage is content from a news report. Please summarize this passage in one sentence or less. \n Passage: Memray is a memory profiler for Python. It can help developers discover the cause of high memory usage, find memory leaks, and find hotspots in code that cause a lot of allocations. Memray can be used both as a command-line tool or as a library. \n Summary: [/INST] "
lorax_client = pb.deployments.client("my-mistral-7b")
print(lorax_client.generate(input_prompt, adapter_id="news-summarizer-model/1", max_new_tokens=100).generated_text)
Learn more about private serverless deployments.
Delete Deployment
By default, your deployment scales down to 0 replicas when it is not in use. While it is scaled to 0, you won't be billed, and as soon as you send a request, your deployment will automatically scale back up. If you don't intend to use the deployment again, you can delete it.
pb.deployments.delete("my-mistral-7b") # The name must match the name used when creating the private serverless deployment
Next Steps
- Try training with your own dataset and use case
- Try training with a larger model (e.g. mixtral-8x7b) to compare performance