# Source code for training_model.grpo_train

"""File for training using grpo method"""
import json
import logging
import os
from typing import Callable, List, Optional, Tuple

from datasets import Dataset
from hydra.utils import get_original_cwd
from omegaconf import DictConfig
from peft import PeftModel
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
from trl import GRPOConfig, GRPOTrainer


def _score_completion(completion: str, correct_actions) -> float:
    """Return the reward for one completion: 1.0 match, 0.0 mismatch, -1.0 malformed."""
    try:
        parsed = json.loads(completion)
    except json.JSONDecodeError:
        return -1.0  # Not JSON at all

    # Valid JSON is not necessarily an object: "[1, 2]", "42" or "null" parse
    # fine but have no ``.get``; treat them as malformed instead of crashing.
    if not isinstance(parsed, dict):
        return -1.0
    content = parsed.get("Content", {})
    if not isinstance(content, dict):
        return -1.0

    # A missing "Action" key is compared as an empty list.
    generated_actions = content.get("Action", [])
    return 1.0 if generated_actions == correct_actions else 0.0


def reward_function(prompts: List[str], completions: List[str], **kwargs) -> list:
    """Compute rewards for GRPO training based on action matching.

    Every completion is parsed as JSON and its ``Content.Action`` value is
    compared for equality with the reference action read from
    ``kwargs["row"]["Content"]["Action"]``.

    Args:
        prompts (list): List of prompts. Unused; accepted for trl consistency.
        completions (list): List of model-generated completions (JSON strings).
        kwargs (dict): Must contain ``row``, either a dict or a JSON string
            encoding one, holding the reference action under
            ``["Content"]["Action"]``.

    Returns:
        list: One reward per completion, in order: ``1.0`` when the generated
        action equals the reference, ``0.0`` when it differs, and ``-1.0``
        when the completion is not a well-formed JSON object.

    Raises:
        TypeError: If ``row`` is absent from ``kwargs`` (it is then ``None``).
        KeyError: If ``row`` lacks ``"Content"`` or ``"Action"``.
    """
    # NOTE(review): prepare_grpo_data in this module emits a "correct_actions"
    # column, not "row" -- confirm the dataset paired with this reward
    # function actually provides "row".
    row = kwargs.get("row")
    if isinstance(row, str):
        row = json.loads(row)
    correct_actions = row["Content"]["Action"]

    return [_score_completion(completion, correct_actions) for completion in completions]
def prepare_grpo_data(
    cfg: DictConfig,
) -> Tuple[Dataset, Dataset]:
    """Prepare datasets for GRPO training with prompts and correct actions.

    Reads the JSON files named by ``cfg.grpo.train_data`` and
    ``cfg.grpo.val_data`` from ``cfg.paths.data_dir`` (joined onto Hydra's
    original working directory) and turns every entry of their ``"examples"``
    list into a record with two fields:

    * ``prompt``: the first ``History`` entry, the available actions and the
      user input rendered into a single string;
    * ``correct_actions``: a one-element list holding the answer's
      ``Content.Action``.

    Any other top-level key of the files (e.g. ``"system"``) is ignored.

    Args:
        cfg (DictConfig): Configuration object

    Returns:
        Tuple[Dataset, Dataset]: Tuple containing train and validation datasets

    Raises:
        KeyError: If a file or example lacks one of the expected keys.
        IndexError: If an example's ``History`` list is empty.
    """
    data_dir = os.path.join(get_original_cwd(), cfg.paths.data_dir)

    def load_json(file_name: str) -> dict:
        """Load one JSON file located in the data directory."""
        with open(os.path.join(data_dir, file_name), "r", encoding="utf-8") as file:
            return json.load(file)

    def build_record(example: dict) -> dict:
        """Convert one raw example into a ``prompt`` / ``correct_actions`` record."""
        logging.debug(example)
        prompt_dict = example["prompt"]
        history = prompt_dict["History"][0]  # First system message
        available_actions = prompt_dict["AvailableActions"]
        user_input = prompt_dict["UserInput"]
        prompt_str = (
            f"{history}\nAvailableActions: {available_actions}\nUser: {user_input}"
        )
        correct_action = example["answer"]["Content"]["Action"]
        # Stored as a list for the reward function.
        return {"prompt": prompt_str, "correct_actions": [correct_action]}

    def process_dataset(dataset: dict) -> Dataset:
        """Build a Dataset from the ``"examples"`` list of a loaded file."""
        # Only "examples" is consumed, so the file is read directly rather than
        # through a copy with "system" popped: that copy had no effect beyond
        # raising KeyError for files without a "system" entry.
        return Dataset.from_list(
            [build_record(example) for example in dataset["examples"]]
        )

    val_raw = load_json(cfg.grpo.val_data)
    train_raw = load_json(cfg.grpo.train_data)

    train_data = process_dataset(train_raw)
    val_data = process_dataset(val_raw)
    return train_data, val_data
def grpo_train(
    model: AutoModel | PeftModel | PreTrainedModel,
    tokenizer: AutoTokenizer | PreTrainedTokenizer,
    cfg: DictConfig,
    data_preparing_func: Optional[Callable],
    reward_func: Callable = reward_function,
) -> int:
    """Run the GRPO training pipeline end to end.

    Builds the train/validation datasets, assembles a ``GRPOConfig`` from the
    ``cfg.training``, ``cfg.grpo`` and ``cfg.model`` sections, trains with a
    ``GRPOTrainer`` and reports how many optimisation steps were performed.

    Args:
        model (AutoModel): LLM model
        tokenizer (AutoTokenizer): LLM tokenizer, also passed to the trainer
            as its processing class
        cfg (DictConfig): Configuration object. When
            ``cfg.grpo.max_completion_length`` equals the string ``"None"``
            it is overwritten in place with ``tokenizer.model_max_length``.
        data_preparing_func (Callable): Function used to prepare the data.
            Called as ``data_preparing_func(cfg, tokenizer,
            should_add_prompt=True)`` and should return
            Tuple[Dataset, Dataset]: train and validation datasets. When
            ``None``, ``prepare_grpo_data(cfg)`` is used instead.
        reward_func (Callable): Reward function for the grpo

    Returns:
        int: Number of global training steps completed
    """
    if data_preparing_func is not None:
        train_data, val_data = data_preparing_func(
            cfg, tokenizer, should_add_prompt=True
        )
    else:
        train_data, val_data = prepare_grpo_data(cfg)
    logging.info("GRPO data prepared")

    logging.debug(type(cfg.grpo.max_completion_length))
    # The config encodes "no explicit limit" as the string "None"; in that
    # case the tokenizer's maximum length is written back into the config.
    if cfg.grpo.max_completion_length == "None":
        cfg.grpo.max_completion_length = tokenizer.model_max_length

    training_cfg = cfg.training
    grpo_cfg = cfg.grpo

    trainer_args = GRPOConfig(
        output_dir=cfg.model.new_model,
        per_device_train_batch_size=training_cfg.per_device_train_batch_size,
        gradient_accumulation_steps=training_cfg.gradient_accumulation_steps,
        learning_rate=training_cfg.learning_rate,
        num_train_epochs=training_cfg.num_train_epochs,
        logging_steps=training_cfg.logging_steps,
        max_completion_length=grpo_cfg.max_completion_length,
        eval_strategy="steps",
        eval_steps=training_cfg.eval_steps,
        warmup_steps=training_cfg.warmup_steps,
        fp16=training_cfg.fp16,
        bf16=training_cfg.bf16,
        weight_decay=training_cfg.weight_decay,
        gradient_checkpointing=training_cfg.gradient_checkpointing,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        report_to="wandb",
        save_total_limit=training_cfg.save_total_limit,
        load_best_model_at_end=training_cfg.load_best,
        num_generations=grpo_cfg.num_generations,
    )

    grpo_trainer = GRPOTrainer(
        model=model,
        args=trainer_args,
        train_dataset=train_data,
        eval_dataset=val_data,
        processing_class=tokenizer,
        reward_funcs=reward_func,
    )
    grpo_trainer.train()
    logging.info("GRPO training completed")

    return grpo_trainer.state.global_step