
Fine-Tuning Visual Language Models (Beta)

What are Visual Language Models?

Visual Language Models (VLMs) extend the capabilities of traditional Large Language Models (LLMs) by incorporating visual inputs alongside text. While LLMs generate or comprehend text based on purely linguistic context, VLMs take this a step further by jointly processing images and text. Essentially, they transform raw pixel data into a form that can be aligned with text embeddings, enabling the model to reason across both modalities.

The core architecture is typically based on transformers, similar to those used in LLMs. However, instead of just attending to sequences of tokens, VLMs also attend to visual features extracted from images using vision encoders (like CNNs, Vision Transformers, SigLIP, etc.). This enables them to understand complex associations between what they "see" in an image and what they "read" in a caption or query.
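To make the alignment step concrete, here is a minimal, purely illustrative sketch (not any particular model's implementation; the dimensions and module names are assumptions): patch features from a vision encoder are projected into the LLM's embedding space and concatenated with the text embeddings before the transformer attends over the joint sequence.

import torch
import torch.nn as nn

# Illustrative dimensions only, e.g. a SigLIP-style encoder output vs. an LLM hidden size.
d_vision, d_model = 1152, 4096

# A learned projection maps vision-encoder patch features into the LLM's embedding space.
projector = nn.Linear(d_vision, d_model)

patch_features = torch.randn(1, 256, d_vision)   # [batch, num_patches, d_vision] from the vision encoder
text_embeddings = torch.randn(1, 32, d_model)    # [batch, num_text_tokens, d_model] from the LLM embedding table

image_tokens = projector(patch_features)
# The transformer then attends over the concatenated image-and-text sequence.
inputs_embeds = torch.cat([image_tokens, text_embeddings], dim=1)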

[Figure: Visual Language Model architecture]

Common Use-Cases

  1. Visual Recognition & Image Reasoning: Identifying objects, activities, or structured elements within images, enhancing tasks like document analysis and product categorization.
  2. Image Captioning: Generating descriptive captions for accessibility, social media, and content tagging.
  3. Visual Question Answering (VQA): Answering questions about an image's content, useful in customer support, education, and smarter image search.
  4. Multimodal Content Generation: Creating text based on visual prompts or combining text and images for synthetic datasets.
  5. Image-Based Search: Enabling more intuitive search engines by combining visual and text-based queries.

Fine-tune

Dataset Preparation

To fine-tune VLMs on Predibase, your dataset must contain 3 columns:

  • prompt
  • completion
  • images

In particular, Predibase currently supports only one image input per prompt-completion pair and requires the image to be a bytestream. We're working on adding support for multiple image inputs, as well as multi-turn chat support for chat-based datasets.

If you have a dataset that resembles the raw example below, you can convert it into the Predibase-compatible fine-tuning format with the following code:

Raw dataset:

Prompt: "What is this a picture of?"
Completion: "A dog"
Image: <URL_TO_IMAGE>

To apply the chat template, run the following code. Note that the text input must contain an image token; the dataset is formatted as follows to ensure that the image token precedes the text:


from transformers import AutoProcessor
from PIL import Image

import pandas as pd
import requests
import io

# Use the Hugging Face ID of the base model you plan to fine-tune,
# e.g. "meta-llama/Llama-3.2-11B-Vision-Instruct".
model_id = "<HUGGING_FACE_MODEL_ID>"
processor = AutoProcessor.from_pretrained(model_id)

def to_bytes(image_url):
    # Predibase currently supports a single image per example, so keep only the first one.
    if isinstance(image_url, list):
        image_url = image_url[0]

    if isinstance(image_url, str):
        # Download the image from its URL and normalize it to RGB.
        image = Image.open(requests.get(image_url, stream=True).raw)
        image = image.convert("RGB")
    else:
        # Already a PIL image.
        image = image_url

    # Serialize the image to a PNG bytestream.
    img_byte_array = io.BytesIO()
    image.save(img_byte_array, format='PNG')
    return img_byte_array.getvalue()

def format_row(row):
    prompt = row["prompt"]
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    # Prepend the image token so the image precedes the text.
                    "text": processor.image_token + prompt,
                },
            ],
        }
    ]
    try:
        formatted_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    except Exception:
        # Fall back to manual formatting if the processor has no chat template.
        formatted_prompt = processor.image_token + prompt + processor.tokenizer.eos_token
    completion = row["completion"]
    image_bytes = to_bytes(row["images"])

    return {"prompt": formatted_prompt, "completion": completion, "images": image_bytes}

train_dataset: pd.DataFrame = train_dataset.apply(format_row, axis=1, result_type="expand")

Once you run this code, your text inputs should each contain an image token (this varies depending on the model, but it should look something like <|image|>), and images should contain raw byte strings.
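You can sanity-check the result directly; for example, using the processor and formatted train_dataset from above:

print(processor.image_token)                  # model-dependent, e.g. "<|image|>"
print(train_dataset.iloc[0]["prompt"])        # should start with the image token
print(type(train_dataset.iloc[0]["images"]))  # should be <class 'bytes'>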

You can save this dataset as a CSV via dataset.to_csv() and upload it to Predibase. Since the feature is currently in Beta, we only allow uploading relatively small datasets (~1000 rows) for VLM fine-tuning. We're working on supporting larger datasets for VLM fine-tuning in the coming weeks.
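For example (the filename is just a placeholder):

train_dataset.to_csv("vlm_train_formatted.csv", index=False)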

Upload to Predibase

Due to the size of VLM datasets, we recommend uploading them to S3 before connecting them to Predibase. You can add your dataset from S3 using the UI. With S3, we can support datasets up to 1 GB.

If you'd prefer to upload it as a file, run:

dataset = pb.datasets.from_file("{Path to local file}", name="doc_vqa_test")

Training

Fine-tuning a VLM in either the UI or the SDK follows the same process as fine-tuning any LLM in Predibase. The only differences are:

  1. You must select a VLM as the base model. See supported VLMs.
  2. Your dataset must be formatted correctly as shown above.

Note that we currently only support instruction tuning for VLMs (not completion or chat completion).

We also currently only support LoRA based fine-tuning for VLMs. Stay tuned for Turbo and Turbo LoRA support!
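For reference, launching a VLM fine-tuning job from the SDK looks just like it does for an LLM. Below is a minimal sketch, assuming the dataset connected above; the API token, repo name, and base model short name are placeholders, and the exact call signatures may vary slightly by SDK version:

from predibase import Predibase, FinetuningConfig

pb = Predibase(api_token="...")

# Placeholder repo name; reuse an existing adapter repo if you have one.
repo = pb.repos.create(name="my-vlm-adapter", exists_ok=True)

adapter = pb.adapters.create(
    config=FinetuningConfig(base_model="llama-3-2-11b-vision-instruct"),  # must be a supported VLM
    dataset=dataset,  # the VLM dataset connected above
    repo=repo,
    description="VLM fine-tuning example",
)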

Inference

Adding images to your prompt

To run inference, you include images directly in your prompt. How you do so depends on the image format:

URLs

prompt = "![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png) What is this a picture of?"

Image Files

With image files, you have to encode them into base64 first:

import base64

image_file_path = "INSERT_FILE_PATH_HERE"

with open(image_file_path, "rb") as f:
    byte_string = base64.b64encode(f.read()).decode()

# You may need to change this image format depending on how the byte stream is encoded
prompt = f"![](data:image/png;base64,{byte_string}) What is this an image of?"

Byte Strings

Byte strings must also be encoded into base64 before being passed to the model:

import base64

byte_string = b"INSERT_BYTE_STRING_HERE"  # must be a bytes object, e.g. the `images` value from your dataset
encoded_byte_string = base64.b64encode(byte_string).decode()

prompt = f"![](data:image/png;base64,{encoded_byte_string}) What is this an image of?"

As shown in these examples, we highly recommend that you insert the image before the text in your requests.

Basic prompting

To prompt your fine-tuned adapter, you may use one of our shared endpoints or create a private deployment.

lorax_client = pb.deployments.client("llama-3-2-11b-vision-instruct")
print(lorax_client.generate(prompt, adapter_id="my-repo/1", max_new_tokens=100).generated_text)

Run on an evaluation dataset

If you have an evaluation dataset in the same format as the formatted train dataset above, you can test your model with the code below:

import pandas as pd
from predibase import Predibase
import base64
import ast
import os

pb = Predibase(api_token="...")
dataset_path = "{PATH_TO_DATASET}"
deployment_name = "{NAME_OF_DEPLOYMENT_TO_QUERY}"
adapter_id = "{ADAPTER_REPO}/{ADAPTER_VERSION}"
output_file = "{PATH_TO_OUTPUT_FILE}"

base_img_str = "![](data:image/png;base64,{})"

dataset = pd.read_csv(dataset_path)
dataset.rename(columns={"completion": "ground_truth"}, inplace=True)

# Grab the unformatted prompt (the <|image|> and <|eot_id|> tokens assume a Llama 3.2 Vision chat template)
dataset["raw_prompt"] = dataset["prompt"].apply(lambda x: x.split("<|image|>")[1].split("<|eot_id|>")[0])

# Format image data to base64 string
dataset['images'] = dataset['images'].apply(lambda x: ast.literal_eval(x))
dataset['images'] = dataset['images'].apply(lambda x: base64.b64encode(x).decode('utf-8'))
dataset['images'] = dataset['images'].apply(lambda x: base_img_str.format(x))

# Replace the image token <|image|> in the prompt column with the corresponding value from the 'images' column
for index, row in dataset.iterrows():
    dataset.at[index, 'prompt'] = row['prompt'].replace("<|image|>", row['images'])

dataset.drop(columns=['images'], inplace=True)

# Initialize Predibase client
client = pb.deployments.client(deployment_name, force_bare_client=True)

# Get adapter completions
adapter_completions = []
for index, row in dataset.iterrows():
    try:
        completion = client.generate(
            row['prompt'],
            adapter_id=adapter_id,
            max_new_tokens=128,
            temperature=1
        )
        adapter_completions.append(completion.generated_text)
    except Exception as e:
        print(f"Failed at index {index}: {e}")
        adapter_completions.append(None)

dataset.drop(columns=['prompt'], inplace=True)

dataset['adapter_completions'] = adapter_completions

# Reorder columns
dataset = dataset[['raw_prompt', 'ground_truth', 'adapter_completions']]

dataset.to_csv(output_file, index=False)