
Quickstart: LLM Fine-Tuning

This quickstart will show you how to prompt, fine-tune, and deploy LLMs in Predibase. We'll follow a code generation use case: the end result is a fine-tuned Llama 2 7B model that takes natural language as input and returns code as output.


info

For Python SDK users, we'd recommend using an interactive notebook environment, such as Jupyter or Google Colab.

Setup

Make sure you've installed the SDK and configured your API token.
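
If you haven't installed the SDK yet, it's available from PyPI (a minimal sketch; see the installation docs for the full setup, including API token configuration):

pip install -U predibase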

If you're using the Python SDK, you'll first need to initialize your PredibaseClient object. All SDK examples below will assume this has already been done.

from predibase import PredibaseClient

# If you've already run `pbase login`, you don't need to provide any credentials here.
#
# If you're running in an environment where the `pbase login` command-line tool isn't available,
# you can also set your API token using `pc = PredibaseClient(token="<your token here>")`.
pc = PredibaseClient()

Deploy a pretrained LLM

info

Only VPC and Premium SaaS users with the Admin role will be able to deploy a pretrained LLM. Predibase Cloud users will have access to shared deployments without the need to manage any deployments themselves.

llm = pc.LLM("hf://meta-llama/Llama-2-7b-hf")
llm_deployment = llm.deploy(deployment_name="llama-2-7b").get()

Prompt a deployed LLM

For our code generation use case, let's first see how Llama 2 7B performs out of the box.

The first line is where we specify which deployed LLM we intend to query. If you are in the Predibase SaaS environment, you have a few shared LLMs available to you, including Llama 2 7B. If you are in a VPC environment, you'll need to deploy an LLM before you can query it.

llm_deployment = pc.LLM("pb://deployments/llama-2-7b")
result: list = llm_deployment.prompt("""
Below is an instruction that describes a task, paired with an input
that may provide further context. Write a response that appropriately
completes the request.

### Instruction: Write an algorithm in Java to reverse the words in a string.

### Input: The quick brown fox

### Response:
""", options={"max_new_tokens": 256})
print(result[0].response)

Fine-tune a pretrained LLM

Next we'll upload a dataset and fine-tune to see if we can get better performance.

The Code Alpaca dataset is used to fine-tune large language models to follow natural-language instructions and produce code. It consists of three columns (a sample row is sketched below):

  • instruction: a description of the task
  • input: additional context for the instruction, when required
  • output: the expected code
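
For illustration, a single training row might look like this (hypothetical values, in the shape of the real columns):

example_row = {
    "instruction": "Write a Python function that squares a number.",
    "input": "",  # empty when the instruction needs no extra context
    "output": "def square(x):\n    return x * x",
}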

Download the Code Alpaca dataset

For the sake of this quickstart, we've created a version of the Code Alpaca dataset with fewer rows so that the model trains significantly faster.

wget https://predibase-public-us-west-2.s3.us-west-2.amazonaws.com/datasets/code_alpaca_800.csv
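
To sanity-check the download, you can peek at the first few rows with pandas (which we'll also use later for batch inference):

import pandas as pd

df = pd.read_csv("code_alpaca_800.csv")
print(df.columns.tolist())  # expected: ['instruction', 'input', 'output']
print(df.head(3))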

Upload the dataset to Predibase and start fine-tuning

The fine-tuning job should take around 35-45 minutes total. Queueing time depends on how quickly we're able to acquire resources and what other jobs are ahead of yours in the queue. The training itself should take around 25-30 minutes. As the model trains, you'll receive metric updates in your notebook or terminal, and you can also view metrics and visualizations in the Predibase UI.

# Upload the dataset to Predibase (estimated time: ~2 minutes, since Predibase also builds a dataset profile)
# If you've already uploaded the dataset before, you can skip the upload and fetch it directly with
# `dataset = pc.get_dataset("code_alpaca_800", "file_uploads")`.
dataset = pc.upload_dataset("code_alpaca_800.csv")

# Define the template used to prompt the model for each training example.
# The {instruction} and {input} placeholders are filled in from the corresponding dataset columns.
prompt_template = """Below is an instruction that describes a task, paired with an input
that may provide further context. Write a response that appropriately
completes the request.

### Instruction: {instruction}

### Input: {input}

### Response:
"""

# Specify the Huggingface LLM you want to fine-tune
# Kick off a fine-tuning job on the uploaded dataset
llm = pc.LLM("hf://meta-llama/Llama-2-7b-hf")
job = llm.finetune(
    prompt_template=prompt_template,
    target="output",
    dataset=dataset,
    # repo="optional-custom-model-repository-name",
)

# Wait for the job to finish and get training updates and metrics
model = job.get()

Download your fine-tuned LLM

info

In this quickstart, we're running adapter-based fine-tuning, so the exported model files will contain only the adapter weights, not the full LLM weights.

model.download(name="llm.zip", location="/path/to/folder")
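
The downloaded archive isn't needed to use your model within Predibase, but if you want to run the adapter elsewhere, a common pattern is to unzip it and load it with Hugging Face PEFT. This is a sketch, not part of the Predibase SDK; it assumes the archive unpacks to a standard PEFT adapter directory, and the path below is hypothetical:

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# "/path/to/folder/llm" is hypothetical: point this at wherever you unzipped llm.zip.
model_with_adapter = PeftModel.from_pretrained(base, "/path/to/folder/llm")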

Prompt your fine-tuned LLM

Batch Inference

info

To run batch inference, you'll need an appropriate engine. Only users with the Admin role will be able to create new engines. Users with the User role can see existing engines on the Engines page in the Predibase UI.

Run batch inference using the predict method on a DataFrame. Note: Engines will automatically start up when used and may take around 10 minutes to initialize.

# Create an engine suitable for inference. An appropriate template can be selected
# using `pc.get_engine_templates()`.
eng = pc.create_engine("llm-batch-engine", template="gpu-a10g-small", auto_suspend=1800, auto_resume=True)

import pandas as pd

test_df = pd.DataFrame.from_dict(
    {
        "instruction": ["Write an algorithm in Java to reverse the words in a string."],
        "input": ["The quick brown fox jumped over"],
    }
)
results = model.predict(targets="output", source=test_df, engine=eng)

pd.set_option('display.max_colwidth', None)
print(results)
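
The same predict call works on larger frames. For example, you could score the first few rows of the downloaded CSV (illustrative only; evaluating on training data is a quick sanity check, not a real evaluation):

full_df = pd.read_csv("code_alpaca_800.csv")
sample_results = model.predict(targets="output", source=full_df.head(10), engine=eng)
print(sample_results)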

Deploy for Real-Time Inference

Deploy your fine-tuned LLM for real-time inference. Once deployed, you can use the prompt method in the SDK to query your model or use the Query Editor in the Predibase UI. Deploying the fine-tuned LLM from this Quickstart guide should take around 10 minutes.

info

Only VPC and Premium SaaS users with the Admin role will be able to deploy a fine-tuned LLM.

finetuned_llm = model.deploy("llama-2-7b-finetuned").get()
result: list = finetuned_llm.prompt("""
Below is an instruction that describes a task, paired with an input
that may provide further context. Write a response that appropriately
completes the request.

### Instruction: Write an algorithm in Java to reverse the words in a string.

### Input: The quick brown fox

### Response:
""", options={"max_new_tokens": 256})
print(result[0].response)