Skip to main content

Migrate from OpenAI

Background

For fine-tuning, Predibase currently supports datasets in a tabular format that contain at least one input column and one output column.

If you're coming from OpenAI, your dataset may already be in the chat completions format that OpenAI recommends, such as:

{"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]}
{"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "Oh, just some guy named William Shakespeare. Ever heard of him?"}]}
{"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "Around 384,400 kilometers. Give or take a few, like that really matters."}]}

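We are adding native support for OpenAI chat-completions-style datasets. Until that support is available, you can use the script below to convert your dataset into the Predibase format. The script writes a CSV file next to the input file, with one row per conversation and a `messages_role_i` / `messages_content_i` column pair for each message in that conversation.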

Conversion Script

If you have any questions or run into any issues, please email support@predibase.com or reach out to us on Discord.

"""Convert an OpenAI chat completions dataset into a Predibase-compatible format.

Usage:
python openai_to_predibase.py <jsonl_file_path>
"""
import json
import math
import os
import sys
from typing import Dict, List

import pandas as pd


def read_jsonl(filename: str) -> List[Dict[str, List[Dict[str, str]]]]:
    """Read JSONL data from a file into a list of dicts.

    Each non-blank line is parsed as one JSON object. Blank or
    whitespace-only lines (for example, stray empty lines between records
    or at the end of the file) are skipped instead of causing a parse error.

    Args:
        filename: Path to the JSONL file. It is decoded as UTF-8; a leading
            UTF-8 byte-order mark, if present, is ignored.

    Returns:
        A list of dicts corresponding to the JSON objects in the file, in
        file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    objs = []
    # "utf-8-sig" decodes plain UTF-8 and also strips a BOM that some Windows
    # editors prepend; relying on the platform default encoding would
    # mis-decode non-ASCII chat content on some systems.
    with open(filename, "r", encoding="utf-8-sig") as f:
        # Iterate the file object directly so lines are streamed rather than
        # all read into memory up front via readlines().
        for line in f:
            if not line.strip():
                continue
            objs.append(json.loads(line))
    return objs


def objs_to_df(objs: List[Dict[str, List[Dict[str, str]]]]) -> pd.DataFrame:
    """Convert a list of chat-completion records into a wide dataframe.

    Args:
        objs: List of dicts, where each dict corresponds to one row of the
            dataset and holds a ``"messages"`` list of dicts with ``"role"``
            and ``"content"`` keys. Any other keys are ignored.

    Returns:
        A dataframe in which each row corresponds to one "messages" list, and
        the messages appear in order as pairs of ``messages_role_i`` and
        ``messages_content_i`` columns. Rows with fewer messages than the
        longest conversation are padded with nulls. An empty ``objs`` yields
        an empty dataframe with no columns.

    Raises:
        KeyError: If a record has no ``"messages"`` key, or a message has no
            ``"role"`` or ``"content"`` key.
    """
    # The longest conversation decides how many role/content column pairs are
    # needed; default=0 keeps an empty input from raising ValueError in max().
    max_num_messages = max((len(obj["messages"]) for obj in objs), default=0)
    num_columns = max_num_messages * 2

    # messages_role_0, messages_content_0, messages_role_1, ...
    column_names = [
        f"messages_{field}_{i}"
        for i in range(max_num_messages)
        for field in ("role", "content")
    ]

    messages_table = []
    for obj in objs:
        # Flatten [{role, content}, ...] into [role, content, role, content, ...].
        row = [
            value
            for message in obj["messages"]
            for value in (message["role"], message["content"])
        ]
        # Pad shorter conversations with nulls so every row has num_columns
        # entries and lines up with column_names.
        row.extend([None] * (num_columns - len(row)))
        messages_table.append(row)

    return pd.DataFrame(messages_table, columns=column_names)


def write_csv(filename: str, df: pd.DataFrame):
    """Save ``df`` as a CSV file alongside ``filename``.

    The output path is ``filename`` with its extension swapped for ``.csv``
    (e.g. ``data.jsonl`` -> ``data.csv``). The dataframe index is omitted.
    """
    stem = os.path.splitext(filename)[0]
    df.to_csv(stem + ".csv", index=False)



def main():
    """Convert the JSONL file named on the command line into a CSV file.

    Reads the input path from ``sys.argv[1]`` and writes the converted data
    next to it, with the extension replaced by ``.csv``. Exits with a usage
    message (status 1) when no path is given; extra arguments are ignored.
    """
    # Fail with a readable usage line instead of an IndexError traceback
    # when the script is run without an argument.
    if len(sys.argv) < 2:
        sys.exit("Usage: python openai_to_predibase.py <jsonl_file_path>")
    filename = sys.argv[1]
    objs = read_jsonl(filename)
    df = objs_to_df(objs)
    write_csv(filename, df)


# Run the conversion only when this file is executed as a script
# (e.g. `python openai_to_predibase.py data.jsonl`), not when it is imported.
if __name__ == "__main__":
    main()