
Example: Fine-tuning & Serving


Learn how to fine-tune and serve a large language model (LLM) for your application. Predibase offers you the ability to seamlessly put open-source LLMs in production without the training headaches or GPU setup.

In this guide, we will fine-tune and serve a text summarizer using mistral-7b, an open-source LLM from Mistral. We will use the same dataset we used for the News Headline Generation task in LoRA Land.

Supported Models

Predibase supports many popular OSS models for fine-tuning including:

  • mistral-7b
  • mixtral-8x7b-v0-1
  • llama-3-8b-instruct
  • llama-3-70b

To see all models available to fine-tune, check out the full list of available models.

Prepare Data

Predibase supports a variety of different data connectors including File Upload, S3, Snowflake, Databricks, and more. You can also upload your data in a few different file formats. We usually recommend CSV or JSONL. See more details about dataset preparation for fine-tuning.

Instruction Tuning

Your dataset should follow the following structure:

  • prompt: The fully materialized input to your model
  • completion: The output of your model

In the case of JSONL, it should look something like:


{"prompt": ..., "completion": ...}
{"prompt": "Please summarize the following article ...", "completion": "Madonna kicks off Celebration World Tour in London"}
{"prompt": "Please summarize the following article ...", "completion": "Facebook Releases First Transparency Report on what Americans see on the platform"}
{"prompt": ..., "completion": ...}
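As a sketch, a JSONL file in this shape can be generated programmatically with the standard library. The article/headline pairs below are placeholders, not rows from the actual dataset:

```python
import json

# Hypothetical (article, headline) pairs standing in for the real dataset.
examples = [
    ("Madonna opened her Celebration World Tour ...",
     "Madonna kicks off Celebration World Tour in London"),
    ("Facebook published its first transparency report ...",
     "Facebook Releases First Transparency Report on what Americans see on the platform"),
]

with open("tldr_dataset.jsonl", "w") as f:
    for article, headline in examples:
        record = {
            "prompt": f"Please summarize the following article ... {article}",
            "completion": headline,
        }
        # One JSON object per line, as JSONL requires.
        f.write(json.dumps(record) + "\n")
```

Each line is an independent JSON object, so the file can be streamed and appended to without re-parsing the whole dataset.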

Train Adapter

You can use the Web UI or the Python SDK to connect your data and start a fine-tuning job. Currently the SDK only supports file uploads, so if you wish to use a data connection (e.g. S3), you'll need to use the UI.

Initialize Predibase Client

Install the client if needed:

pip install predibase

Then from a Python script or notebook:

from predibase import Predibase, FinetuningConfig, DeploymentConfig

pb = Predibase(api_token="<PREDIBASE API TOKEN>")

You can generate an API token on the homepage or find an existing key under Settings > My Profile.

Connect Data to Predibase

dataset = pb.datasets.from_file("/path/tldr_dataset.csv", name="tldr_dataset")

We will use the tldr_news dataset available at this Google Drive Link. You can find the original dataset, with a numerical split column and no prompt template, on HuggingFace.
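Before uploading, it can help to sanity-check that the CSV has the columns Predibase expects for instruction tuning. A minimal sketch using the standard library (the helper name is ours, not part of the SDK):

```python
import csv

def check_finetuning_csv(path):
    """Verify the CSV has the 'prompt' and 'completion' columns and count rows."""
    with open(path, newline="") as f:
        reader = csv.DictReader(f)
        missing = {"prompt", "completion"} - set(reader.fieldnames or [])
        if missing:
            raise ValueError(f"missing required columns: {sorted(missing)}")
        # Count the remaining data rows (header already consumed).
        return sum(1 for _ in reader)
```

Running this before `pb.datasets.from_file` catches schema mistakes locally rather than partway through an upload.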

Kick-off Training

We can start a fine-tuning job with the recommended defaults and see live metric updates. See the list of available models. (Note: you must use the short names provided in the list.)

# Create an adapter repository
repo = pb.repos.create(name="news-summarizer-model", description="TLDR News Summarizer Experiments", exists_ok=True)

# Start a fine-tuning job, blocks until training is finished
adapter = pb.adapters.create(
config=FinetuningConfig(
base_model="mistral-7b"
),
dataset=dataset, # Also accepts the dataset name as a string
repo=repo,
description="initial model with defaults"
)

Customize Hyperparameters (Optional)

Currently, we support modifying epochs, rank, learning rate, and target modules, and are working to expose additional hyperparameters soon.

# Create an adapter repository
repo = pb.repos.create(name="news-summarizer-model", description="TLDR News Summarizer Experiments", exists_ok=True)

# Start a fine-tuning job with custom parameters, blocks until training is finished
adapter = pb.adapters.create(
config=FinetuningConfig(
base_model="mistral-7b",
epochs=1, # default: 3
rank=8, # default: 16
learning_rate=0.0001, # default: 0.0002
target_modules=["q_proj", "v_proj", "k_proj"], # default: None (infers [q_proj, v_proj] for mistral-7b)
),
dataset=dataset,
repo=repo,
description="changing epochs, rank, and learning rate"
)

Each adapter version has its own page in the Adapters UI.

Monitor Progress

You can always return to your fine-tuning job in the SDK.

# Get adapter, blocking call if training is still in progress
adapter = pb.adapters.get("news-summarizer-model/1")
adapter

Use Your Adapter

Start by prompting your adapter using a serverless endpoint (no deployment necessary). Once you're happy with your adapter's performance, create a dedicated deployment for production use.

Serverless Endpoints ($/token)

Serverless endpoints are a shared resource we offer for getting started, experimentation, and fast iteration. If your base model is one that is hosted as a serverless endpoint, you can use your fine-tuned model instantly by utilizing LoRAX:

input_prompt="<s>[INST] The following passage is content from a news report. Please summarize this passage in one sentence or less. \n Passage: Memray is a memory profiler for Python. It can help developers discover the cause of high memory usage, find memory leaks, and find hotspots in code that cause a lot of allocations. Memray can be used both as a command-line tool or as a library. \n Summary: [/INST] "

lorax_client = pb.deployments.client("mistral-7b")
print(lorax_client.generate(input_prompt, adapter_id="news-summarizer-model/1", max_new_tokens=100).generated_text)

The first line returns a LoRAX client that we can use for prompting. The second line calls generate while passing in the adapter repo name and version to prompt our fine-tuned model.
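The `[INST] ... [/INST]` wrapper in the prompt above follows Mistral's instruct format, which the base model was trained on. If you prompt many passages, a small helper (hypothetical, not part of the Predibase SDK) keeps the template consistent with what was used at fine-tuning time:

```python
def make_summary_prompt(passage: str) -> str:
    """Wrap a news passage in the Mistral instruct template used for training."""
    return (
        "<s>[INST] The following passage is content from a news report. "
        "Please summarize this passage in one sentence or less. \n "
        f"Passage: {passage} \n Summary: [/INST] "
    )
```

You would then call `lorax_client.generate(make_summary_prompt(passage), adapter_id="news-summarizer-model/1", max_new_tokens=100)`. Keeping the inference template identical to the training template matters: a mismatch between the two typically degrades adapter quality.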

We can compare the fine-tuned model to the base model by calling generate without adapter_id:

print(lorax_client.generate(input_prompt, max_new_tokens=100).generated_text)
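To compare base and fine-tuned outputs side by side over several passages, a small helper can call generate with and without the adapter. The `DummyClient` below is a stand-in for illustration only, so the sketch runs without a live deployment; with a real LoRAX client you would pass `pb.deployments.client("mistral-7b")` instead:

```python
def compare_outputs(client, prompt, adapter_id, max_new_tokens=100):
    """Return (base_output, adapter_output) for the same prompt."""
    base = client.generate(prompt, max_new_tokens=max_new_tokens).generated_text
    tuned = client.generate(
        prompt, adapter_id=adapter_id, max_new_tokens=max_new_tokens
    ).generated_text
    return base, tuned

# Stand-in client so the sketch is runnable without a deployment.
class _Response:
    def __init__(self, text):
        self.generated_text = text

class DummyClient:
    def generate(self, prompt, adapter_id=None, max_new_tokens=100):
        return _Response("tuned summary" if adapter_id else "base summary")
```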

Dedicated Deployments ($/GPU-hour)

Once you're ready for production, deploy a private instance of the base model for greater reliability, control, and no rate limiting. LoRAX enables you to serve an unlimited number of adapters on a single base model deployment.

Predibase officially supports serving these models. Note that dedicated deployments are always on by default; to change this, modify the cooldown_time parameter. For base_model, you'll need the Hugging Face path, which can be found in the list of officially supported models.

# Deploy
pb.deployments.create(
name="my-mistral-7b",
config=DeploymentConfig(
base_model="mistral-7b",
# cooldown_time=3600 # Value in seconds, defaults to 0 which means deployment is always on
)
# description="", # Optional
)

# Prompt
input_prompt="<s>[INST] The following passage is content from a news report. Please summarize this passage in one sentence or less. \n Passage: Memray is a memory profiler for Python. It can help developers discover the cause of high memory usage, find memory leaks, and find hotspots in code that cause a lot of allocations. Memray can be used both as a command-line tool or as a library. \n Summary: [/INST] "
lorax_client = pb.deployments.client("my-mistral-7b")
print(lorax_client.generate(input_prompt, adapter_id="news-summarizer-model/1", max_new_tokens=100).generated_text)

Learn more about dedicated deployments.

Download Model

To download your adapter weights:

pb.adapters.download("news-summarizer-model/1")

Note that the exported model files will contain only the adapter weights, not the full LLM weights.

Next Steps

  • Try training with your own dataset and use case
  • Try training with a larger model (e.g. mixtral-8x7b) to compare performance