
Group Relative Policy Optimization (GRPO)

Overview

With the advent of reasoning models like DeepSeek R1, reinforcement learning (RL) has become more effective and important for fine-tuning models. This is because traditional post-training methods like supervised fine-tuning don't work as well for reasoning models as they do for non-reasoning models. While reasoning models are great out of the box, they can be significantly improved through reinforcement learning on downstream tasks.

Predibase now supports reinforcement learning through Group Relative Policy Optimization (GRPO), an innovative RL method introduced in DeepSeek's groundbreaking R1 paper. Unlike traditional RL approaches such as Reinforcement Learning with Human Feedback (RLHF), which require collecting labeled preference data to train reward models, GRPO enables direct optimization of model behavior using programmable reward functions. This allows models to develop generalized strategies for solving tasks without the need for any (or extensive) human feedback data collection.

How does GRPO work?

GRPO Training Loop

GRPO follows an iterative training process:

  1. Dataset Input: The process begins with creating a dataset of prompts that will be used to train the model; as few as a few dozen examples can be enough.

  2. Models: Inside the GRPO Trainer, two components work together:

    • Frozen LLM: Acts as a reference model that maintains baseline performance
    • Trainable LLM: The model being optimized through the training process

    The reason there are two models is to maintain stability during training. The frozen LLM acts as an anchor point, ensuring the trainable LLM doesn't deviate too drastically from the original model's behavior while still optimizing for the target task. This dual-model approach helps prevent catastrophic forgetting and maintains the model's general capabilities while improving on the specific task at hand.

  3. Completion Generation: For each prompt, the training model (often called the policy model in RL literature) generates N different completions through temperature-based sampling. This sampling method introduces controlled randomness into the generation process, producing variations in the completions while maintaining coherence. By generating multiple slightly different completions for the same prompt, we create a diverse set of responses that can be evaluated in the next stage.

  4. Reward Scoring: A reward server evaluates each completion using predefined programmable reward functions, assigning scores based on the quality or correctness of the responses. These scores are summed to produce a final score for each completion in the group. The mean and standard deviation of the group's scores are then used to compute "advantages": completions that scored above the group average receive a positive advantage, and those below it receive a negative advantage. This provides a clear learning signal about which patterns to reinforce and which to avoid in future generations (a minimal sketch of this advantage computation appears after this list).

  5. Iterative Improvements: This process repeats continuously during training, allowing the model to progressively learn better strategies for generating high-quality completions. Through each iteration:

    • The model refines its approach based on the reward signals, learning which patterns lead to higher rewards
    • The policy model's weights are updated to favor strategies that produced better completions
    • The latest version of the policy model is used to generate completions for the next batch of prompts
    • The proximity to the reference model helps maintain stability while improving performance

This iterative cycle enables the model to discover and reinforce effective reasoning patterns while avoiding behaviors that lead to lower rewards. Over time, the model improves through direct optimization of task-specific metrics without requiring explicit labeled data or human feedback.
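
To make the reward scoring step concrete, here is a minimal, framework-agnostic sketch of how group-relative advantages can be computed from the summed reward scores of a group of completions. The function name and the exact normalization details (e.g., sample vs. population standard deviation) are illustrative assumptions, not the Predibase implementation:

from statistics import mean, stdev

def group_relative_advantages(reward_scores: list[float]) -> list[float]:
    """Normalize each completion's summed reward against its group's statistics."""
    mu = mean(reward_scores)
    sigma = stdev(reward_scores) if len(reward_scores) > 1 else 0.0
    if sigma == 0.0:
        # All completions scored the same, so this group provides no learning signal
        return [0.0 for _ in reward_scores]
    # Above-average completions get a positive advantage, below-average a negative one
    return [(score - mu) / sigma for score in reward_scores]

# Summed reward scores for N = 4 completions of the same prompt
print(group_relative_advantages([2.0, 1.0, 0.0, 1.0]))  # roughly [1.22, 0.0, -1.22, 0.0]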

We've optimized this process by leveraging Low-Rank Adaptation (LoRA) instead of updating the full model weights. LoRA adapters allow us to efficiently fine-tune models by only training a small set of parameters while keeping the base model frozen. We use high-rank LoRA adapters by default to ensure sufficient model capacity for learning complex reasoning patterns. The training process is powered by our production-grade multi-LoRA serving infrastructure, LoRAX, which enables rapid iteration between training steps by efficiently updating and serving the policy model within the training loop.

From our experiments, we've found that GRPO is not limited to showing improvements only on reasoning models - it works remarkably well even for non-reasoning models like Qwen-2.5 and Llama-3. This makes it a versatile approach for enhancing model performance across different model architectures and capabilities.

When to use Reinforcement Fine-Tuning

The flowchart below helps determine whether to use Reinforcement Fine-Tuning (RFT), Supervised Fine-Tuning (SFT), or RLHF based on your data and task characteristics.

RFT vs SFT

How To Use GRPO in Predibase

Prepare Your Dataset

For RFT, you need, at minimum, a text dataset with a prompt field. Your dataset can optionally include other columns as well; when provided, these columns are accessible from within your reward functions in case you need their values.
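
For example, a dataset for the countdown task used later in this guide could look like the rows below. Only the prompt column is required; the nums and target columns are optional extras that the example reward functions below read from the example dictionary (the exact column names are up to you, and these rows are illustrative):

# Illustrative rows: a required "prompt" column plus optional extra columns
rows = [
    {
        "prompt": "Using the numbers [17, 64, 63, 26], create an equation that equals 44. ...",
        "nums": "[17, 64, 63, 26]",  # extra values are passed to reward functions as strings
        "target": "44",
    },
    {
        "prompt": "Using the numbers [3, 5, 7], create an equation that equals 16. ...",
        "nums": "[3, 5, 7]",
        "target": "16",
    },
]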

Defining Reward Functions

After you have uploaded your dataset, you can define reward function(s). These reward functions will be used to score your model's generations during training.

All reward functions must follow this function signature:

def reward_fn(prompt: str, completion: str, example: dict[str, str]) -> float

The prompt is a prompt from your dataset. The completion is one of the N outputs the model generated for that prompt. The example is the original data sample from your dataset, represented as a dictionary. If you define more than one reward function, make sure each one has a unique function name.

Example Reward Functions

Here are two example reward functions that demonstrate common patterns in GRPO reward functions:

  1. A format reward function that ensures the model's output follows a specific structure
  2. An equation reward function that validates mathematical correctness for a generated math expression

The example below is from a math problem-solving task where the model needs to:

  • Generate equations using specific numbers to reach a target value
  • Format its response with both reasoning (<think> tags) and a final answer (<answer> tags), as instructed in the prompt.

Example input prompt for this task:

<|im_start|>system
You are a helpful assistant. You first think about the reasoning process step by step
and then provide the user with an answer.<|im_end|>
<|im_start|>user
Using the numbers [17, 64, 63, 26], create an equation that equals 44. You can use
basic arithmetic operations (+, -, *, /) and parentheses, and each number can only
be used once. Show your work in <think> </think> tags. And return the final equation
and answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.<|im_end|>
<|im_start|>assistant
Let me solve this step by step.

Format Reward Function
def format_reward_func(prompt: str, completion: str, example: dict[str, str]) -> float:
    # Imported packages must be inside each reward function
    import re

    try:
        # Check if the format matches expected pattern:
        # <think> content </think> followed by <answer> content </answer>
        regex = (
            r"^<think>\s*([^<]*(?:<(?!/?think>)[^<]*)*)\s*<\/think>\n"
            r"<answer>\s*([\s\S]*?)\s*<\/answer>$"
        )
        match = re.search(regex, completion, re.DOTALL)
        if match is not None and len(match.groups()) == 2:
            return 1.0
        return 0.0
    except Exception:
        return 0.0

This reward function checks whether the model's output follows the expected format: step-by-step reasoning within <think> tags followed by the final answer within <answer> tags. It returns 1.0 if the format is correct and 0.0 otherwise.

Format reward functions are important for two key reasons:

  1. They encourage chain-of-thought reasoning by requiring the model to explicitly show its work, which often leads to more accurate solutions.

  2. They enforce consistent output structures that can be reliably parsed for downstream validation and processing. In this example, the format allows us to extract the final answer from the <answer> tags to verify its correctness.

When combined with other reward functions (like the equation validator), format rewards help create a multi-objective optimization that balances both structural compliance and task-specific correctness.

Equation Reward Function
def equation_reward_func(prompt: str, completion: str, example: dict[str, str]) -> float:
    # Imported packages must be inside each reward function
    import re
    import ast

    try:
        match = re.search(r"<answer>\s*([\s\S]*?)\s*<\/answer>", completion)
        if not match:
            return 0.0

        # Extract and validate equation
        equation = match.group(1).strip()
        if not re.match(r'^[\d+\-*/().\s]+$', equation):
            return 0.0

        # Extract and validate numbers
        used_numbers = [int(n) for n in re.findall(r'\d+', equation)]
        nums = ast.literal_eval(example["nums"]) if isinstance(example["nums"], str) else example["nums"]
        if sorted(used_numbers) != sorted(nums):
            return 0.0

        # Evaluate equation and check result
        result = eval(equation, {"__builtins__": None}, {})
        if abs(float(result) - float(example["target"])) < 1e-5:
            return 1.0

        return 0.0

    except Exception:
        return 0.0

This reward function evaluates the actual correctness of the task:

  1. Extracts the equation from within the <answer> tags
  2. Validates that the equation only contains valid mathematical symbols and numbers
  3. Checks that the equation uses exactly the numbers provided in the input (no more, no less) since that is the required constraint in our problem
  4. Evaluates the equation and verifies it equals the target value
  5. Returns 1.0 if all checks pass, 0.0 otherwise
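
As a quick sanity check, you can run both reward functions locally on a sample completion before launching a training job. The completion and scores below are illustrative; the total is simply the sum of the individual rewards, mirroring how scores are combined during training:

example = {"nums": "[17, 64, 63, 26]", "target": "44"}
completion = (
    "<think>17 + 64 = 81, 81 - 63 = 18, 18 + 26 = 44.</think>\n"
    "<answer>17 + 64 - 63 + 26</answer>"
)

format_score = format_reward_func("", completion, example)      # 1.0: tags are well formed
equation_score = equation_reward_func("", completion, example)  # 1.0: uses each number once and equals 44
total_score = format_score + equation_score                     # 2.0: feeds into the group advantage calculation
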
Kicking Off Your GRPO Training Job

Once you have your reward functions defined, you can pass them in through the GRPOConfig in the Predibase SDK.

adapter = pb.adapters.create(
    config=GRPOConfig(
        base_model="qwen2.5-7b-instruct",
        reward_fns=RewardFunctionsConfig(
            functions={
                "format": RewardFunction.from_callable(format_reward_func),
                "equation": RewardFunction.from_callable(equation_reward_func),
            },
        )
    ),
    dataset="countdown",
    repo=repo,
    description="..."
)

Reward Function FAQs

How do I pass additional columns to my reward functions?

You can add additional columns to your dataset beyond the prompt column. These columns will be passed to your reward functions in the example argument as a dictionary with the key being the column name and the value being the column value. Please note that all values in the example dictionary are strings, so you will need to convert them to the appropriate type depending on your use case.

Do the rewards need to be binary?

Nope! You can write more intricate reward functions that return a score between 0 and 1 (such as for similarity scores or partial credit scoring). You can take a look at section 4.1 of our blog post here for an example.
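
As a rough illustration (not the example from the blog post), a partial-credit reward might give half credit when the answer tags are present but the extracted answer is wrong, reusing the target column from the earlier example:

def partial_credit_reward_func(prompt: str, completion: str, example: dict[str, str]) -> float:
    # Imported packages must be inside each reward function
    import re

    match = re.search(r"<answer>\s*([\s\S]*?)\s*<\/answer>", completion)
    if not match:
        return 0.0  # no parsable answer at all
    if match.group(1).strip() == example["target"].strip():
        return 1.0  # exact match gets full credit
    return 0.5      # well-formed but incorrect answer gets partial credit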

What if I need to use an external library in my reward function?

You can import any libraries in your reward function. However, please note that the import must be inside the function definition. If external packages need to be installed for your reward functions to work, you can specify them as follows:

def my_reward_function(prompt, completion, example) -> float:
    import my_pkg
    return my_pkg.score(...)

cfg = RewardFunctionsConfig(
    runtime=RewardFunctionsRuntimeConfig(
        packages=[
            "my_pkg",
        ]
    ),
    functions={
        "my_reward": RewardFunction.from_callable(my_reward_function)
    },
)

How are data types handled in the example dictionary?

The example dictionary passed to reward functions contains string values for all columns, regardless of their original data type. This means you'll need to convert non-string data types back to their original form. For example:

  • Lists, dictionaries, sets and other data structures can be converted using ast.literal_eval()
  • Numbers can be converted using float() or int()
  • Booleans can also be converted with ast.literal_eval() (note that bool() returns True for any non-empty string, including "False")
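
For instance, reusing the countdown columns from earlier (the is_valid column is a hypothetical addition just to illustrate boolean handling):

import ast

example = {"nums": "[17, 64, 63, 26]", "target": "44", "is_valid": "True"}

nums = ast.literal_eval(example["nums"])          # [17, 64, 63, 26]
target = int(example["target"])                   # 44
is_valid = ast.literal_eval(example["is_valid"])  # True; bool("False") would incorrectly give True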

Any additional tips?

At a minimum, it is usually good to have a format reward function, so the model learns to respond in the structure you want, and a correctness reward function that checks whether the answer or expected output is actually good.