Predibase supports reinforcement learning through Group Relative Policy Optimization (GRPO), which enables direct optimization of model behavior using programmable reward functions. Unlike approaches such as PPO and DPO that rely on labeled preference data, GRPO allows models to develop generalized strategies for solving tasks without extensive human-feedback data collection.

Reinforcement fine-tuning is only supported in Developer and Enterprise Tiers. Users in the Free Tier will need to upgrade or book a demo with our team.

For more detailed information about GRPO and best practices, we highly recommend checking out our comprehensive user guide.

Quick Start

To get started with reinforcement fine-tuning:

  1. Prepare a dataset with a prompt field
  2. Define your reward functions
  3. Start training with GRPO

For a deeper understanding of how GRPO works and when to use it, see our reinforcement learning guide.

Dataset Requirements

Your dataset must include a prompt field containing the input text. Optionally, you can include additional columns that can be accessed within your reward functions.
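
For example, for the Countdown task used later in this guide, dataset rows might look like the following sketch (the nums and target columns are illustrative extra fields that the reward functions read from the example argument):

# Illustrative dataset rows for the Countdown example below.
# Only "prompt" is required; "nums" and "target" are additional columns
# that reward functions can access through the `example` argument.
rows = [
    {
        "prompt": "Using the numbers [17, 64, 63, 26], create an equation that equals 44. ...",
        "nums": "[17, 64, 63, 26]",
        "target": 44,
    },
    {
        "prompt": "Using the numbers [3, 7, 25, 50], create an equation that equals 75. ...",
        "nums": "[3, 7, 25, 50]",
        "target": 75,
    },
]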

Define Reward Functions

What are reward functions?

Reward functions are Python functions that evaluate the quality of a model’s output by assigning a numerical score. They serve as the training signal that guides the model to improve its responses. Each reward function can focus on different aspects of the output, such as following a specific format, maintaining factual accuracy, or meeting task-specific requirements.

All reward functions must follow this Python function signature:

def reward_fn(prompt: str, completion: str, example: dict[str, str]) -> float

Parameters:

  • prompt: Input prompt from dataset
  • completion: Output generated by the model during training
  • example: Dictionary containing all fields from the dataset row, including any additional columns that can be used in reward calculations.

Return value:

  • Numeric score (can be int or float, 0-1 range recommended)

Non-numeric return values default to a reward of 0. Always return valid numeric values to properly guide training.
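
As a minimal illustration of this signature, the sketch below rewards concise completions; the 200-character threshold is an arbitrary choice made purely for this example.

def conciseness_reward_func(prompt: str, completion: str, example: dict[str, str]) -> float:
    # Illustrative only: give full reward to completions under an arbitrary
    # 200-character threshold, and no reward otherwise.
    try:
        return 1.0 if len(completion) <= 200 else 0.0
    except Exception:
        # Returning 0.0 on failure keeps the reward numeric, as required above
        return 0.0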

Example Reward Functions

Here are two example reward functions that demonstrate common GRPO patterns:

  1. A format reward function that ensures the model’s output follows a specific structure
  2. A task-specific reward function that grades how well the model completes your target task. For example, if you are training a model to solve math word problems, a task-specific reward function can grade the accuracy of the model’s solution

Let’s say you are training a model to play Countdown, a game where the model must create an equation from a given set of numbers to reach a target number, using only basic arithmetic operations (+, -, *, /).

Example input prompt for this task:

<|im_start|>system
You are a helpful assistant. You first think about the reasoning process step by step
and then provide the user with an answer.<|im_end|>
<|im_start|>user
Using the numbers [17, 64, 63, 26], create an equation that equals 44. You can use
basic arithmetic operations (+, -, *, /) and parentheses, and each number can only
be used once. Show your work in <think> </think> tags. And return the final equation
and answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.<|im_end|>
<|im_start|>assistant
Let me solve this step by step.

For this task, you can define a task-specific reward function that grades the accuracy of the model’s solution.

Format Reward Function

def format_reward_func(prompt: str, completion: str, example: dict[str, str]) -> float:
    # Imported packages must be inside each reward function
    import re

    try:
        # Check if the format matches expected pattern:
        # <think> content </think> followed by <answer> content </answer>
        regex = (
            r"^<think>\s*([^<]*(?:<(?!/?think>)[^<]*)*)\s*<\/think>\n"
            r"<answer>\s*([\s\S]*?)\s*<\/answer>$"
        )
        match = re.search(regex, completion, re.DOTALL)
        if match is not None and len(match.groups()) == 2:
            return 1.0
        return 0.0

    except Exception:
        return 0.0

Equation Reward Function

def equation_reward_func(prompt: str, completion: str, example: dict[str, str]) -> float:
    # Imported packages must be inside each reward function
    import re
    import ast

    try:
        match = re.search(r"<answer>\s*([\s\S]*?)\s*<\/answer>", completion)
        if not match:
            return 0.0

        # Extract and validate equation
        equation = match.group(1).strip()
        if not re.match(r'^[\d+\-*/().\s]+$', equation):
            return 0.0

        # Extract and validate numbers
        used_numbers = [int(n) for n in re.findall(r'\d+', equation)]
        nums = ast.literal_eval(example["nums"]) if isinstance(example["nums"], str) else example["nums"]
        if sorted(used_numbers) != sorted(nums):
            return 0.0

        # Evaluate equation and check result
        result = eval(equation, {"__builtins__": None}, {})
        if abs(float(result) - float(example["target"])) < 1e-5:
            return 1.0

        return 0.0

    except Exception:
        return 0.0
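
Because reward functions are plain Python, it can help to sanity-check them locally on a few hand-written completions before launching a training run. The harness below is a small sketch; the sample completion and dataset row are made up for illustration.

# Local sanity check of the reward functions on a hand-written completion.
sample_prompt = "Using the numbers [17, 64, 63, 26], create an equation that equals 44. ..."
sample_completion = (
    "<think>64 - 63 = 1, then 17 + 26 = 43, and 43 + 1 = 44.</think>\n"
    "<answer>(64 - 63) + 17 + 26</answer>"
)
sample_example = {"nums": "[17, 64, 63, 26]", "target": 44}

print(format_reward_func(sample_prompt, sample_completion, sample_example))    # expected: 1.0
print(equation_reward_func(sample_prompt, sample_completion, sample_example))  # expected: 1.0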

Start Training

Use the Predibase SDK to configure and start training:

from predibase import GRPOConfig, RewardFunctionsConfig

adapter = pb.adapters.create(
    config=GRPOConfig(
        base_model="qwen2-5-7b-instruct",
        reward_fns=RewardFunctionsConfig(
            functions={
                "format": format_reward_func,
                "answer": equation_reward_func,
            },
        ),
        train_steps=200,  # Minimum recommended: 200
    ),
    dataset="training_dataset",
    repo=repo,
)
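
The snippet above assumes that pb is an authenticated Predibase client and that repo is an adapter repository you have already created. A typical setup might look like the following (the repository name is illustrative):

import os
from predibase import Predibase

# Assumes the PREDIBASE_API_TOKEN environment variable is set.
pb = Predibase(api_token=os.environ["PREDIBASE_API_TOKEN"])

# Create (or reuse) an adapter repository to hold the trained adapters.
repo = pb.repos.create(name="countdown-grpo", exists_ok=True)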

Monitor Training

You can find more detailed information in our comprehensive user guide.

Reward Graphs

Monitor training progress through:

  • total_reward: Average combined reward across all reward functions
  • total_reward_std: Standard deviation indicating performance consistency
  • Individual reward function graphs showing learning progress for each objective in the Predibase UI

When interpreting these graphs:

  • Look for an overall upward trend in rewards
  • Expect 40-50 training steps before clear improvements
  • Format-related rewards often improve before complex task-specific rewards
  • High variance early in training should decrease over time

Reward Logs

Use the Reward Logs tab to:

  • Track training progress
  • Debug reward function issues
  • Identify performance issues
  • Investigate reward score behavior

Completions Tab

Compare model outputs during training:

  • View completions side-by-side
  • See reward scores for each completion
  • Track improvements across epochs
  • Detect reward hacking
  • Validate output formatting

Using External Libraries

Import additional packages in reward functions:

from predibase import RewardFunctionsConfig, RewardFunctionsRuntimeConfig

def my_reward_function(prompt, completion, example) -> float:
    import my_pkg
    return my_pkg.score(...)

cfg = RewardFunctionsConfig(
    runtime=RewardFunctionsRuntimeConfig(
        packages=[
            "mypkg", # Add any packages you need here
        ]
    ),
    functions={
        "my_reward": my_reward_function,
    },
)
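
The resulting config can then be passed to GRPOConfig in place of the inline RewardFunctionsConfig used in the earlier training example:

from predibase import GRPOConfig

adapter = pb.adapters.create(
    config=GRPOConfig(
        base_model="qwen2-5-7b-instruct",
        reward_fns=cfg,  # reuses the config with the runtime packages defined above
        train_steps=200,
    ),
    dataset="training_dataset",
    repo=repo,
)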

Update Reward Functions

You can also modify reward functions during training. This is useful when you want to add new criteria or change how outputs are evaluated, whether to fix issues or to keep improving learning as it starts to saturate.

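For illustration, my_reward_function_v2 below is a hypothetical revision of the earlier example; any updated function just needs to keep the same reward function signature.

def my_reward_function_v2(prompt: str, completion: str, example: dict[str, str]) -> float:
    # Hypothetical revision: same signature as before, stricter scoring.
    import my_pkg

    try:
        score = float(my_pkg.score(completion))  # my_pkg.score is assumed to return a number
        return score if score >= 0.5 else 0.0    # only credit sufficiently confident scores
    except Exception:
        return 0.0
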
# Get current config
cfg = pb.adapters.get_config("myrepo/1")

# Update existing or add new functions
cfg.reward_fns["my_reward"] = my_reward_function_v2

# Apply updates
pb.adapters.update_config("myrepo/1", cfg)

Next Steps

For more detailed information about GRPO and best practices, we highly recommend checking out our comprehensive user guide.

To see example use cases and implementations: