dreadnode.training
API reference for the dreadnode.training module.
Training module with lazy imports for heavy dependencies.
This module uses lazy loading to avoid importing torch/ray unless needed. Heavy dependencies (torch, ray, transformers, vllm) are only loaded when the user actually accesses training-related classes.
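An illustrative sketch of the module-level lazy-import pattern described above (PEP 562 `__getattr__`); the real attribute map and submodule paths inside dreadnode.training may differ:

    # Sketch only: attribute names and submodule paths are hypothetical.
    import importlib
    import typing as t

    _LAZY_ATTRS = {
        "SFTTrainer": ".sft",
        "DPOTrainer": ".dpo",
        "PPOTrainer": ".ppo",
    }

    def __getattr__(name: str) -> t.Any:
        if name in _LAZY_ATTRS:
            # Heavy deps (torch, ray, vllm) are only imported here, on first access.
            module = importlib.import_module(_LAZY_ATTRS[name], __package__)
            return getattr(module, name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")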
AsyncRayGRPOTrainer
AsyncRayGRPOTrainer(config: RayGRPOConfig)
Async Ray-based GRPO trainer.
Uses separate GPUs for inference and training to overlap computation:
- GPU 0: vLLM inference (generates batches continuously)
- GPU 1: Training (processes batches as they arrive)
This achieves much higher throughput than the colocated version.
Requires at least 2 GPUs.
shutdown
shutdown() -> None
Shut down workers.
train(prompts: Sequence[str], reward_fn: RewardFn, num_steps: int | None = None) -> TrainingState
Run async GRPO training.
Overlaps inference and training for maximum throughput.
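A minimal usage sketch under stated assumptions (not copied from the codebase): it assumes the reward function follows the (prompts, completions) -> rewards convention shown for RayGRPOTrainer below, and that train() can be called directly; if your release exposes it as a coroutine, await it instead.

    from dreadnode.training import AsyncRayGRPOTrainer, RayGRPOConfig

    config = RayGRPOConfig(
        model_name="Qwen/Qwen2.5-1.5B-Instruct",
        num_prompts_per_step=8,
        num_generations_per_prompt=4,
    )

    def reward_fn(prompts, completions):
        # Toy reward: prefer shorter completions.
        return [1.0 / (1.0 + len(c)) for c in completions]

    trainer = AsyncRayGRPOTrainer(config)  # requires at least 2 GPUs
    try:
        state = trainer.train(prompts=["Solve 2 + 2."], reward_fn=reward_fn, num_steps=10)
    finally:
        trainer.shutdown()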
DPOConfig
DPOConfig(model_name: str = "Qwen/Qwen2.5-1.5B-Instruct", tokenizer_name: str | None = None, beta: float = 0.1, label_smoothing: float = 0.0, loss_type: str = "sigmoid", max_seq_length: int = 2048, max_prompt_length: int = 512, learning_rate: float = 5e-07, weight_decay: float = 0.01, warmup_ratio: float = 0.1, max_steps: int = 1000, max_epochs: int = 1, batch_size: int = 4, gradient_accumulation_steps: int = 4, max_grad_norm: float = 1.0, ref_model_offload: bool = True, log_interval: int = 10, checkpoint_interval: int = 100, checkpoint_dir: str = "./checkpoints", seed: int = 42, trust_remote_code: bool = True)
Configuration for DPO training.
batch_size: int = 4
Batch size per device.
beta: float = 0.1
Temperature parameter for the DPO loss. Higher = more conservative updates.
checkpoint_dir: str = './checkpoints'
Directory for checkpoints.
checkpoint_interval: int = 100
Steps between checkpoints.
gradient_accumulation_steps: int = 4
Gradient accumulation steps.
label_smoothing: float = 0.0
Label smoothing for the DPO loss (0 = no smoothing).
learning_rate: float = 5e-07
Learning rate (DPO typically uses a lower LR than SFT).
log_interval: int = 10
Steps between logging.
loss_type: str = 'sigmoid'
Loss type: 'sigmoid' (standard DPO), 'hinge', 'ipo'.
max_epochs: int = 1
Maximum training epochs.
max_grad_norm: float = 1.0
Maximum gradient norm.
max_prompt_length: int = 512
Maximum prompt length.
max_seq_length: int = 2048
Maximum sequence length.
max_steps: int = 1000
Maximum training steps.
model_name: str = 'Qwen/Qwen2.5-1.5B-Instruct'
Model name or path.
ref_model_offload: bool = True
Keep the reference model on CPU to save GPU memory.
seed: int = 42
Random seed.
tokenizer_name: str | None = None
Tokenizer name (defaults to model_name).
trust_remote_code: bool = True
Trust remote code in the model repository.
warmup_ratio: float = 0.1
Warmup steps as a fraction of total.
weight_decay: float = 0.01
Weight decay.
DPOTrainer
DPOTrainer(config: DPOConfig, fsdp_config: FSDP2Config | None = None, storage: Storage | None = None, checkpoint_name: str | None = None)
DPO (Direct Preference Optimization) trainer.
DPO directly optimizes the policy using preference pairs without needing a separate reward model or PPO. This makes it much simpler than RLHF.
The training process:
- Load policy model and frozen reference model
- For each preference pair (chosen, rejected):
- Compute log probabilities for both under policy and reference
- Compute DPO loss to prefer chosen over rejected
- Update policy via gradient descent
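For orientation, the standard sigmoid DPO objective behind these steps (the textbook formulation that loss_type='sigmoid' and beta correspond to; stated here for reference, not copied from the implementation), where y_w is the chosen and y_l the rejected completion:

    \mathcal{L}_{\mathrm{DPO}}(\theta) = -\log \sigma\!\left(\beta\left[\log\frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \log\frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right]\right)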
Attributes:
config – DPO configuration
model – Training policy model
ref_model – Frozen reference model
tokenizer – Tokenizer
Initialize DPO trainer.
Parameters:
config (DPOConfig) – DPO configuration
fsdp_config (FSDP2Config | None, default: None) – Optional FSDP2 configuration
storage (Storage | None, default: None) – Optional storage for CAS checkpointing
checkpoint_name (str | None, default: None) – Name for checkpoints
get_model
get_model() -> nn.Module
Get the trained model.
save_checkpoint
save_checkpoint() -> None
Save training checkpoint.
train(dataset: Dataset | list[PreferencePair] | list[dict]) -> dict[str, float]
Run DPO training.
Parameters:
dataset (Dataset | list[PreferencePair] | list[dict]) – Training dataset with preference pairs. Each item should have 'prompt', 'chosen', and 'rejected' keys.
Returns:
dict[str, float] – Final training metrics
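A hedged usage sketch with an in-memory list of preference dicts following the documented 'prompt'/'chosen'/'rejected' schema; exact import paths and checkpoint behavior may differ between releases.

    from dreadnode.training import DPOConfig, DPOTrainer

    config = DPOConfig(model_name="Qwen/Qwen2.5-1.5B-Instruct", max_steps=200, beta=0.1)
    dataset = [
        {
            "prompt": "Summarize: The cat sat on the mat.",
            "chosen": "A cat sat on a mat.",
            "rejected": "Dogs are great pets.",
        },
    ]

    trainer = DPOTrainer(config)
    metrics = trainer.train(dataset)   # final training metrics
    trainer.save_checkpoint()
    model = trainer.get_model()        # trained policy as an nn.Module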
PPOConfig
PPOConfig(model_name: str = "Qwen/Qwen2.5-1.5B-Instruct", tokenizer_name: str | None = None, reward_model_name: str | None = None, clip_ratio: float = 0.2, value_clip_ratio: float = 0.2, kl_coef: float = 0.1, kl_target: float | None = 0.01, entropy_coef: float = 0.01, gamma: float = 1.0, gae_lambda: float = 0.95, max_seq_length: int = 2048, max_new_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.9, learning_rate: float = 1e-06, critic_lr: float = 1e-05, weight_decay: float = 0.01, warmup_ratio: float = 0.1, max_steps: int = 1000, batch_size: int = 8, mini_batch_size: int = 4, ppo_epochs: int = 4, gradient_accumulation_steps: int = 1, max_grad_norm: float = 1.0, ref_model_offload: bool = True, share_critic: bool = False, critic_warmup_steps: int = 0, log_interval: int = 10, checkpoint_interval: int = 100, checkpoint_dir: str = "./checkpoints", seed: int = 42, trust_remote_code: bool = True)
Configuration for PPO training.
batch_size: int = 8
Prompts per batch.
checkpoint_dir: str = './checkpoints'
Directory for checkpoints.
checkpoint_interval: int = 100
Steps between checkpoints.
clip_ratio: float = 0.2
PPO clipping ratio (epsilon).
critic_lr: float = 1e-05
Learning rate for the value function (typically higher than the policy LR).
critic_warmup_steps: int = 0
Pretrain the critic for N steps before PPO (0 = no warmup).
entropy_coef: float = 0.01
Entropy bonus coefficient.
gae_lambda: float = 0.95
GAE lambda for advantage estimation.
gamma: float = 1.0
Discount factor (1.0 for episodic tasks like text generation).
gradient_accumulation_steps: int = 1
Gradient accumulation steps.
kl_coef: float = 0.1
KL penalty coefficient.
kl_target: float | None = 0.01
Target KL divergence. If exceeded, the KL coefficient is increased.
learning_rate: float = 1e-06
Learning rate for the policy.
log_interval: int = 10
Steps between logging.
max_grad_norm: float = 1.0
Maximum gradient norm.
max_new_tokens: int = 512
Maximum new tokens to generate.
max_seq_length: int = 2048
Maximum sequence length.
max_steps: int = 1000
Maximum training steps.
mini_batch_size: int = 4
Mini-batch size for PPO updates.
model_name: str = 'Qwen/Qwen2.5-1.5B-Instruct'
Policy model name or path.
ppo_epochs: int = 4
Number of PPO epochs per batch of experience.
ref_model_offload: bool = True
Keep the reference model on CPU to save GPU memory.
reward_model_name: str | None = None
Reward model name or path. If None, a reward_fn must be provided to train().
seed: int = 42
Random seed.
share_critic: bool = False
Share weights between policy and critic (adds a value head to the policy).
temperature: float = 0.7
Sampling temperature.
tokenizer_name: str | None = None
Tokenizer name (defaults to model_name).
top_p: float = 0.9
Top-p sampling.
trust_remote_code: bool = True
Trust remote code in the model repository.
value_clip_ratio: float = 0.2
Value function clipping ratio.
warmup_ratio: float = 0.1
Warmup steps as a fraction of total.
weight_decay: float = 0.01
Weight decay.
PPOTrainer
PPOTrainer(config: PPOConfig, fsdp_config: FSDP2Config | None = None, storage: Storage | None = None, checkpoint_name: str | None = None)
PPO (Proximal Policy Optimization) trainer for RLHF.
Implements the full PPO algorithm with:
- Policy network (actor)
- Value network (critic)
- GAE advantage estimation
- Clipped surrogate objective
- KL penalty and adaptive KL coefficient
The training loop:
- Generate responses from current policy
- Compute rewards using reward model/function
- Estimate advantages with GAE
- Update policy and value networks with PPO
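For orientation, the standard clipped surrogate objective and GAE recursion these steps refer to (textbook formulations, stated for reference rather than copied from the implementation; clip_ratio plays the role of epsilon, gamma and gae_lambda of gamma and lambda):

    L^{\mathrm{CLIP}}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\!\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right)\hat{A}_t\right)\right], \qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}

    \hat{A}_t = \sum_{l \ge 0} (\gamma\lambda)^l \delta_{t+l}, \qquad \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)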
Attributes:
config – PPO configuration
policy – Policy (actor) model
critic – Value (critic) model
ref_model – Frozen reference model for the KL penalty
tokenizer – Tokenizer
Initialize PPO trainer.
Parameters:
config (PPOConfig) – PPO configuration
fsdp_config (FSDP2Config | None, default: None) – Optional FSDP2 configuration
storage (Storage | None, default: None) – Optional storage for CAS checkpointing
checkpoint_name (str | None, default: None) – Name for checkpoints
get_policy
get_policy() -> nn.Module
Get the trained policy model.
save_checkpoint
save_checkpoint() -> None
Save training checkpoint.
train(prompts: list[str], reward_fn: Callable[[list[str], list[str]], list[float]] | None = None) -> dict[str, float]
Run PPO training.
Parameters:
prompts (list[str]) – List of training prompts
reward_fn (Callable[[list[str], list[str]], list[float]] | None, default: None) – Optional reward function (prompts, completions) -> rewards. Required if reward_model_name is not set in the config.
Returns:
dict[str, float] – Final training metrics
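A hedged usage sketch with a custom reward function, required here because reward_model_name is left unset; the reward function receives (prompts, completions) and returns one float per completion.

    from dreadnode.training import PPOConfig, PPOTrainer

    config = PPOConfig(model_name="Qwen/Qwen2.5-1.5B-Instruct", max_steps=100, batch_size=8)

    def reward_fn(prompts: list[str], completions: list[str]) -> list[float]:
        # Toy reward: favor completions that offer a reason.
        return [1.0 if "because" in c.lower() else 0.0 for c in completions]

    trainer = PPOTrainer(config)
    metrics = trainer.train(prompts=["Explain why the sky is blue."] * 32, reward_fn=reward_fn)
    policy = trainer.get_policy()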
RMConfig
RMConfig(model_name: str = "Qwen/Qwen2.5-1.5B-Instruct", tokenizer_name: str | None = None, value_head_hidden_size: int | None = None, value_head_dropout: float = 0.1, pooling: str = "last", max_seq_length: int = 2048, max_prompt_length: int = 512, learning_rate: float = 1e-05, weight_decay: float = 0.01, warmup_ratio: float = 0.1, max_steps: int = 1000, max_epochs: int = 3, batch_size: int = 4, gradient_accumulation_steps: int = 4, max_grad_norm: float = 1.0, margin: float = 0.0, log_interval: int = 10, checkpoint_interval: int = 100, checkpoint_dir: str = "./checkpoints", seed: int = 42, trust_remote_code: bool = True)
Configuration for Reward Model training.
batch_size: int = 4
Batch size per device.
checkpoint_dir: str = './checkpoints'
Directory for checkpoints.
checkpoint_interval: int = 100
Steps between checkpoints.
gradient_accumulation_steps: int = 4
Gradient accumulation steps.
learning_rate: float = 1e-05
Learning rate.
log_interval: int = 10
Steps between logging.
margin: float = 0.0
Margin for the Bradley-Terry loss (0 = no margin).
max_epochs: int = 3
Maximum training epochs.
max_grad_norm: float = 1.0
Maximum gradient norm.
max_prompt_length: int = 512
Maximum prompt length.
max_seq_length: int = 2048
Maximum sequence length.
max_steps: int = 1000
Maximum training steps.
model_name: str = 'Qwen/Qwen2.5-1.5B-Instruct'
Base model name or path.
pooling: str = 'last'
Pooling method: 'last' (last non-pad token), 'mean', 'max'.
seed: int = 42
Random seed.
tokenizer_name: str | None = None
Tokenizer name (defaults to model_name).
trust_remote_code: bool = True
Trust remote code in the model repository.
value_head_dropout: float = 0.1
Dropout for the value head.
value_head_hidden_size: int | None = None
Hidden size for the value head. None = match the model hidden size.
warmup_ratio: float = 0.1
Warmup steps as a fraction of total.
weight_decay: float = 0.01
Weight decay.
RayGRPOConfig
RayGRPOConfig(model_name: str = "Qwen/Qwen2.5-1.5B-Instruct", tokenizer_name: str | None = None, num_prompts_per_step: int = 8, num_generations_per_prompt: int = 4, max_steps: int = 1000, max_epochs: int = 10, max_new_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.9, learning_rate: float = 1e-06, weight_decay: float = 0.01, warmup_ratio: float = 0.1, gradient_accumulation_steps: int = 1, max_grad_norm: float = 1.0, log_interval: int = 10, eval_interval: int = 100, checkpoint_interval: int = 100, checkpoint_dir: str = "./checkpoints", seed: int = 42, vllm: VLLMConfig = VLLMConfig(), training: TrainingConfig = TrainingConfig(), loss: GRPOLossConfig = GRPOLossConfig())
Complete configuration for Ray-based GRPO training.
This configuration controls all aspects of GRPO training:
- Model and tokenizer
- Generation (vLLM)
- Training (DeepSpeed/FSDP)
- GRPO algorithm parameters
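For orientation, GRPO's group-relative advantage in its standard form (the exact variant used here is governed by GRPOLossConfig): each of the G = num_generations_per_prompt completions for a prompt receives a reward r_i, and its advantage is that reward normalized within the group:

    \hat{A}_i = \frac{r_i - \operatorname{mean}(r_1, \ldots, r_G)}{\operatorname{std}(r_1, \ldots, r_G) + \varepsilon}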
checkpoint_dir: str = './checkpoints'
Directory for checkpoints.
checkpoint_interval: int = 100
Steps between checkpoints.
eval_interval: int = 100
Steps between evaluations.
gradient_accumulation_steps: int = 1
Gradient accumulation steps.
learning_rate: float = 1e-06
Learning rate.
log_interval: int = 10
Steps between logging.
loss: GRPOLossConfig = field(default_factory=GRPOLossConfig)
GRPO loss configuration.
max_epochs: int = 10
Maximum training epochs.
max_grad_norm: float = 1.0
Maximum gradient norm for clipping.
max_new_tokens: int = 512
Maximum tokens to generate per completion.
max_steps: int = 1000
Maximum training steps.
model_name: str = 'Qwen/Qwen2.5-1.5B-Instruct'
Model name or path.
num_generations_per_prompt: int = 4
Number of completions to generate per prompt (G in GRPO).
num_prompts_per_step: int = 8
Number of unique prompts per training step.
seed: int = 42
Random seed for reproducibility.
temperature: float = 0.7
Sampling temperature.
tokenizer_name: str | None = None
Tokenizer name (defaults to model_name).
top_p: float = 0.9
Top-p (nucleus) sampling.
train_batch_size: int
Total batch size for training.
training: TrainingConfig = field(default_factory=TrainingConfig)
Distributed training configuration.
vllm: VLLMConfig = field(default_factory=VLLMConfig)
vLLM inference configuration.
warmup_ratio: float = 0.1
Warmup steps as a fraction of total.
weight_decay: float = 0.01
Weight decay.
to_dict() -> dict[str, Any]
Convert to a dictionary for serialization.
RayGRPOTrainer
RayGRPOTrainer(config: RayGRPOConfig, colocate: bool = False, storage: Storage | None = None, checkpoint_name: str | None = None, callbacks: list[TrainerCallback] | None = None)
Native Ray-based GRPO trainer with colocated inference/training.
Supports two modes:
- Memory-efficient mode (default): time-shares the GPU between vLLM and training. Lower memory use, but slower due to model loading/unloading.
- Fast mode (colocate=True): keeps both models loaded and uses in-place vLLM weight updates. Higher memory use, but much faster (no reload overhead).
Example
    config = RayGRPOConfig(
        model_name="Qwen/Qwen2.5-1.5B-Instruct",
        num_generations_per_prompt=4,
    )
    trainer = RayGRPOTrainer(config, colocate=True)  # Fast mode

    def reward_fn(prompts, completions):
        return [1.0 if is_correct(c) else 0.0 for c in completions]

    trainer.train(prompts, reward_fn)
Initialize GRPO trainer.
Parameters:
config (RayGRPOConfig) – GRPO configuration.
colocate (bool, default: False) – If True, keep both vLLM and the training model loaded (faster but more memory).
storage (Storage | None, default: None) – Optional Storage for CAS-based checkpointing.
checkpoint_name (str | None, default: None) – Name for checkpoints (defaults to a sanitized model name).
callbacks (list[TrainerCallback] | None, default: None) – List of TrainerCallback instances for customizing training behavior.
add_callback
add_callback(callback: TrainerCallback) -> None
Add a callback to the trainer.
remove_callback
remove_callback(callback_type: type) -> None
Remove all callbacks of a given type.
save_checkpoint_to_storage
save_checkpoint_to_storage(version: str | None = None) -> LocalModel | None
Public method to save a checkpoint to CAS.
Parameters:
version (str | None, default: None) – Version string. If None, auto-increments.
Returns:
LocalModel | None – LocalModel instance if storage is configured, None otherwise.
shutdown
shutdown() -> None
Shut down the trainer.
train(prompts: Sequence[str], reward_fn: RewardFn, eval_prompts: Sequence[str] | None = None, num_steps: int | None = None) -> TrainingState
Run GRPO training.
Parameters:
prompts (Sequence[str]) – Training prompts.
reward_fn (RewardFn) – Function to score completions.
eval_prompts (Sequence[str] | None, default: None) – Optional evaluation prompts.
num_steps (int | None, default: None) – Optional number of steps (overrides config).
Returns:
TrainingState – Final training state.
RewardModelTrainer
RewardModelTrainer(config: RMConfig, fsdp_config: FSDP2Config | None = None, storage: Storage | None = None, checkpoint_name: str | None = None)
Reward Model trainer using Bradley-Terry loss.
Trains a model to predict scalar rewards from preference pairs. The trained model can then be used in RLHF pipelines (PPO, GRPO, etc.).
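For reference, the standard Bradley-Terry pairwise loss this style of trainer optimizes (textbook formulation, not copied from the implementation; the margin config field corresponds to m below):

    \mathcal{L}_{\mathrm{RM}}(\theta) = -\log \sigma\big(r_\theta(x, y_{\mathrm{chosen}}) - r_\theta(x, y_{\mathrm{rejected}}) - m\big)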
Attributes:
config – Reward model configuration
model – The reward model (base LLM + value head)
tokenizer – Tokenizer
Initialize Reward Model trainer.
Parameters:
config (RMConfig) – Reward model configuration
fsdp_config (FSDP2Config | None, default: None) – Optional FSDP2 configuration
storage (Storage | None, default: None) – Optional storage for CAS checkpointing
checkpoint_name (str | None, default: None) – Name for checkpoints
compute_rewards
compute_rewards(texts: list[str], batch_size: int = 8) -> list[float]
Compute rewards for a list of texts.
Parameters:
texts (list[str]) – List of text sequences
batch_size (int, default: 8) – Batch size for inference
Returns:
list[float] – List of scalar rewards
get_model
get_model() -> RewardModel
Get the trained reward model.
get_reward_fn
get_reward_fn() -> callable
Get a reward function for use with GRPO/PPO.
Returns:
callable – A callable that takes texts and returns rewards
save_checkpoint
save_checkpoint() -> None
Save training checkpoint.
train(dataset: Dataset | list[dict]) -> dict[str, float]
Run reward model training.
Parameters:
dataset (Dataset | list[dict]) – Training dataset with preference pairs. Each item should have 'prompt', 'chosen', and 'rejected' keys.
Returns:
dict[str, float] – Final training metrics
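A hedged sketch of training a reward model on preference pairs and reusing it as a reward function for GRPO/PPO; the dataset schema follows the documented 'prompt'/'chosen'/'rejected' keys, and the exact signature of the returned reward callable may differ between releases.

    from dreadnode.training import RMConfig, RewardModelTrainer

    config = RMConfig(model_name="Qwen/Qwen2.5-1.5B-Instruct", max_steps=500)
    pairs = [
        {
            "prompt": "Write a haiku about rain.",
            "chosen": "Soft rain on tin roofs...",
            "rejected": "Rain is water falling.",
        },
    ]

    rm_trainer = RewardModelTrainer(config)
    metrics = rm_trainer.train(pairs)

    reward_fn = rm_trainer.get_reward_fn()   # callable: texts -> rewards
    scores = rm_trainer.compute_rewards(["Prompt plus completion text"], batch_size=8)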
SFTConfig
SFTConfig(model_name: str = "Qwen/Qwen2.5-1.5B-Instruct", tokenizer_name: str | None = None, max_seq_length: int = 2048, use_packing: bool = True, packing_efficiency_threshold: float = 0.9, learning_rate: float = 2e-05, weight_decay: float = 0.01, warmup_ratio: float = 0.1, max_steps: int = 1000, max_epochs: int = 3, batch_size: int = 4, gradient_accumulation_steps: int = 1, max_grad_norm: float = 1.0, log_interval: int = 10, checkpoint_interval: int = 100, checkpoint_dir: str = "./checkpoints", seed: int = 42, trust_remote_code: bool = True)
Configuration for SFT training.
batch_size: int = 4
Batch size per device.
checkpoint_dir: str = './checkpoints'
Directory for checkpoints.
checkpoint_interval: int = 100
Steps between checkpoints.
gradient_accumulation_steps: int = 1
Gradient accumulation steps.
learning_rate: float = 2e-05
Learning rate.
log_interval: int = 10
Steps between logging.
max_epochs: int = 3
Maximum training epochs.
max_grad_norm: float = 1.0
Maximum gradient norm.
max_seq_length: int = 2048
Maximum sequence length.
max_steps: int = 1000
Maximum training steps.
model_name: str = 'Qwen/Qwen2.5-1.5B-Instruct'
Model name or path.
packing_efficiency_threshold: float = 0.9
Minimum packing efficiency before padding.
seed: int = 42
Random seed.
tokenizer_name: str | None = None
Tokenizer name (defaults to model_name).
trust_remote_code: bool = True
Trust remote code in the model repository.
use_packing: bool = True
Enable sequence packing for efficiency.
warmup_ratio: float = 0.1
Warmup steps as a fraction of total.
weight_decay: float = 0.01
Weight decay.
SFTTrainer
SFTTrainer(config: SFTConfig, fsdp_config: FSDP2Config | None = None)
SFT trainer with sequence packing and FSDP2 support.
Features:
- Sequence packing for efficient training
- FSDP2 distributed training
- Gradient accumulation
- Mixed precision (bf16)
- Checkpointing
Initialize SFT trainer.
Parameters:
config (SFTConfig) – SFT configuration
fsdp_config (FSDP2Config | None, default: None) – Optional FSDP2 configuration
load_checkpoint
load_checkpoint(path: str) -> None
Load training checkpoint.
save_checkpoint
save_checkpoint() -> None
Save training checkpoint.
train(dataset: Dataset | Sequence[dict], eval_dataset: Dataset | Sequence[dict] | None = None) -> dict[str, float]
Run SFT training.
Parameters:
dataset (Dataset | Sequence[dict]) – Training dataset
eval_dataset (Dataset | Sequence[dict] | None, default: None) – Optional evaluation dataset
Returns:
dict[str, float] – Final training metrics
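A hedged usage sketch on an in-memory dataset. The per-item schema is not spelled out in this reference, so the prompt/completion dicts below are an assumption for illustration; adapt them to your dataset format.

    from dreadnode.training import SFTConfig, SFTTrainer

    config = SFTConfig(
        model_name="Qwen/Qwen2.5-1.5B-Instruct",
        max_steps=200,
        use_packing=True,
    )

    dataset = [
        {"prompt": "Translate to French: hello", "completion": "bonjour"},
        {"prompt": "Translate to French: cat", "completion": "chat"},
    ]

    trainer = SFTTrainer(config)
    metrics = trainer.train(dataset, eval_dataset=None)
    trainer.save_checkpoint()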
TinkerSFTConfig
TinkerSFTConfig(base_model: str = "meta-llama/Llama-3.1-8B-Instruct", base_url: str | None = None, lora_rank: int = 16, data_dir: str = "data", train_split: str = "train", eval_split: str | None = "test", max_train_examples: int | None = None, max_eval_examples: int | None = None, max_sequence_length: int = 2048, batch_size: int = 16, gradient_accumulation_steps: int = 1, learning_rate: float = 0.0001, steps: int = 100, checkpoint_interval: int = 10, adam_beta1: float = 0.9, adam_beta2: float = 0.95, adam_eps: float = 1e-08, sample_prompt: str = "", max_new_tokens: int = 64, temperature: float = 0.0, num_samples: int = 4, skip_sample: bool = False, project: str | None = None, run_name: str | None = None, tags: list[str] = ["training", "sft", "tinker"], seed: int = 0)
Configuration for Tinker-based supervised fine-tuning.
This configuration is used to set up LoRA-based SFT training with the Tinker framework.
Example
    config = TinkerSFTConfig(
        base_model="meta-llama/Llama-3.1-8B-Instruct",
        learning_rate=1e-4,
        steps=100,
        lora_rank=16,
    )
adam_beta1: float = 0.9
Adam beta1 parameter.
adam_beta2: float = 0.95
Adam beta2 parameter.
adam_eps: float = 1e-08
Adam epsilon parameter.
base_model: str = 'meta-llama/Llama-3.1-8B-Instruct'
Model name or path for the base model to fine-tune.
base_url: str | None = None
Tinker service URL. If None, uses the default from the environment.
batch_size: int = 16
Number of sequences per training step.
checkpoint_interval: int = 10
Save a checkpoint every N training steps.
data_dir: str = 'data'
Directory containing parquet dataset files.
eval_split: str | None = 'test'
Prefix for evaluation data files. Set to None to skip eval.
gradient_accumulation_steps: int = 1
Number of micro-batches to accumulate before each optimizer step.
learning_rate: float = 0.0001
Adam optimizer learning rate.
lora_rank: int = 16
LoRA rank parameter for adapter training.
max_eval_examples: int | None = None
Maximum number of evaluation examples. None for all.
max_new_tokens: int = 64
Maximum new tokens when sampling.
max_sequence_length: int = 2048
Maximum sequence length for tokenization (truncates from the left).
max_train_examples: int | None = None
Maximum number of training examples. None for all.
num_samples: int = 4
Number of samples to generate after training.
project: str | None = None
Dreadnode project name for logging.
run_name: str | None = None
Dreadnode run name.
sample_prompt: str = ''
Prompt used for sampling after training.
seed: int = 0
Random seed for batch selection.
skip_sample: bool = False
Skip sampling after training checkpoints.
steps: int = 100
Total number of training steps.
tags: list[str] = field(default_factory=lambda: ["training", "sft", "tinker"])
Tags for the Dreadnode run.
temperature: float = 0.0
Sampling temperature (0.0 for greedy).
train_split: str = 'train'
Prefix for training data files (e.g., 'train_*.parquet').
__post_init__() -> None
Validate configuration after initialization.
TinkerSFTTrainer
TinkerSFTTrainer(config: TinkerSFTConfig, training_client: TrainingClient | None = None, service_client: ServiceClient | None = None, callbacks: Sequence[TrainingCallback] | None = None)
Trainer for supervised fine-tuning using Tinker with LoRA.
This trainer provides:
- LoRA-based fine-tuning via Tinker service
- Checkpoint saving and artifact logging
- Optional sampling after training
- Integration with Dreadnode for experiment tracking
Example
    # Create configuration
    config = TinkerSFTConfig(
        base_model="meta-llama/Llama-3.1-8B-Instruct",
        steps=100,
        lora_rank=16,
    )

    # Create trainer
    trainer = TinkerSFTTrainer(config)

    state = trainer.train(train_data)
    print(f"Final loss: {state.losses[-1]:.4f}")
Initialize the Tinker SFT trainer.
Parameters:
config (TinkerSFTConfig) – Training configuration.
training_client (TrainingClient | None, default: None) – Optional pre-initialized Tinker training client.
service_client (ServiceClient | None, default: None) – Optional pre-initialized Tinker service client.
callbacks (Sequence[TrainingCallback] | None, default: None) – Optional list of training callbacks.
renderer
renderer: Any
Get the model-specific renderer (initializes clients if needed).
service_client
service_client: ServiceClient
Get the service client (initializes clients if needed).
tokenizer
tokenizer: Any
Get the tokenizer (initializes clients if needed).
training_client
training_client: TrainingClient
Get the training client (initializes clients if needed).
add_callback
add_callback(callback: TrainingCallback) -> None
Add a training callback.
evaluate
evaluate(eval_data: list[Datum], step: int = 0, log_to_dreadnode: bool = True) -> float
Run evaluation on the provided data.
Parameters:
eval_data (list[Datum]) – Evaluation data as Tinker Datum objects.
step (int, default: 0) – Current training step (for logging).
log_to_dreadnode (bool, default: True) – Whether to log metrics to Dreadnode.
Returns:
float – Evaluation loss.
sample
sample() -> list[dict[str, str]]
Generate samples from the fine-tuned model.
Returns:
list[dict[str, str]] – List of sample dictionaries with 'prompt' and 'completion' keys.
save_checkpoint
save_checkpoint(name: str | None = None) -> str
Save the current model weights as a checkpoint.
Parameters:
name (str | None, default: None) – Optional checkpoint name.
Returns:
str – Path to the saved checkpoint.
train(train_data: list[Datum], eval_data: list[Datum] | None = None, log_to_dreadnode: bool = True) -> TrainingState
Run supervised fine-tuning.
Parameters:
train_data (list[Datum]) – Training data as Tinker Datum objects.
eval_data (list[Datum] | None, default: None) – Optional evaluation data.
log_to_dreadnode (bool, default: True) – Whether to log metrics to Dreadnode.
Returns:
TrainingState – Final training state.
Raises:
ValueError – If the training data is empty.
TrainingModel
One base model available for hosted training jobs.
TrainingModelPricing
Optional upstream pricing metadata.
All values are USD per million tokens. None means "not published" — callers should fall back to the live Tinker console for authoritative numbers (pricing changes faster than we can update the SDK).
VerificationResult
VerificationResult(passed: bool, score: float, metrics: dict[str, Any] = dict())
Outcome of grading a rollout against a task's verification config.
Attributes:
passed (bool) – Whether the task was considered solved.
score (float) – Scalar in [0, 1]. For binary env_flag / env_script this is 1.0 on pass and 0.0 on fail. For llm_judge this is the judge's rubric score.
metrics (dict[str, Any]) – Free-form metadata attached to traces and training metrics (method, exit_code, judge reason and attributes, ...).
__getattr__
__getattr__(name: str) -> t.Any
Lazily load training components to avoid importing torch/ray at module load.
batched_environments
batched_environments(envs: list[TaskEnvironment], *, max_concurrent_setup: int = 32) -> AsyncIterator[list[TaskEnvironment]]
Provision a batch of envs in parallel; tear them all down on exit.
Caps concurrent setup via a semaphore so a 64-rollout RL step doesn't pummel the sandbox provider at batch boundaries. Envs that fail setup() are logged and excluded from the yielded list; their teardown() is not called (nothing to tear down). Envs that succeeded setup() are always torn down on exit — even if the caller raises inside the async with block.
Parameters:
envs (list[TaskEnvironment]) – Pre-constructed TaskEnvironment instances. They must not already be set up (setup() is called by this context manager).
max_concurrent_setup (int, default: 32) – Maximum concurrent setup() calls. Defaults to 32; tune down under tight provider quota.
Yields:
AsyncIterator[list[TaskEnvironment]] – The live envs (those that succeeded setup()), in the input order, with failed envs skipped.
Example:

    envs = [
        TaskEnvironment(
            api_client=api,
            org=ORG,
            workspace=WS,
            task_ref="pwn/flag",
            inputs=row.get("inputs"),
        )
        for row in batch_rows
    ]
    async with batched_environments(envs, max_concurrent_setup=8) as live:
        rewards = await asyncio.gather(*[score(env) for env in live])

run_in_sandbox
run_in_sandbox(code: str, timeout_seconds: int = 300, memory_mb: int = 2048) -> dict
Run code in a Prime Intellect sandbox.
Sandboxes are lightweight execution environments for running AI-generated code or quick experiments.
Parameters:
code (str) – Python code to execute.
timeout_seconds (int, default: 300) – Execution timeout.
memory_mb (int, default: 2048) – Memory limit in MB.
Returns:
dict – Dict with stdout, stderr, and return_code.
Example
    result = await run_in_sandbox('''
    import torch
    print(f"CUDA available: {torch.cuda.is_available()}")
    ''')
    print(result["stdout"])
train_dpo
train_dpo(config_dict: dict[str, Any], prompts: list[str]) -> t.Any
Train with DPO.
train_grpo
train_grpo(config_dict: dict[str, Any], prompts: list[str], reward_fn: Callable[..., Any]) -> t.Any
Train with GRPO.
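A hedged sketch of the train_grpo convenience wrapper. The accepted config_dict keys are assumed to mirror RayGRPOConfig fields, and the return type is left opaque here; whether the call must be awaited depends on the release, so check your version.

    from dreadnode.training import train_grpo

    def reward_fn(prompts, completions):
        # Toy reward: prefer completions under 200 characters.
        return [float(len(c) < 200) for c in completions]

    state = train_grpo(
        config_dict={"model_name": "Qwen/Qwen2.5-1.5B-Instruct", "max_steps": 50},
        prompts=["Write a limerick about GPUs."],
        reward_fn=reward_fn,
    )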
train_on_prime
train_on_prime(config: dict[str, Any] | None = None, name: str | None = None, gpu_type: str = "H100_80GB", gpu_count: int = 1, training_type: str = "sft", requirements: list[str] | None = None, env_vars: dict[str, str] | None = None, auto_terminate: bool = True, region: str | None = None, interruptible: bool = False) -> TrainingResult
Run training on Prime Intellect infrastructure.
This function provides a high-level interface for running training jobs on Prime’s decentralized GPU compute.
Parameters:
config (dict[str, Any] | None, default: None) – Training configuration dict. Common options:
- model_name: Model name or path
- max_steps: Maximum training steps
- batch_size: Batch size per device
- learning_rate: Learning rate
- checkpoint_dir: Checkpoint directory
name (str | None, default: None) – Job name.
gpu_type (str, default: 'H100_80GB') – GPU type (H100_80GB, A100_80GB, etc.).
gpu_count (int, default: 1) – Number of GPUs.
training_type (str, default: 'sft') – Type of training (sft, grpo, dpo, ppo).
requirements (list[str] | None, default: None) – Additional Python requirements.
env_vars (dict[str, str] | None, default: None) – Environment variables.
auto_terminate (bool, default: True) – Terminate pods after training.
region (str | None, default: None) – Preferred region.
interruptible (bool, default: False) – Use spot/interruptible instances.
Returns:
TrainingResult – TrainingResult with the final state and checkpoint info.
Example
    # SFT training on H100s
    result = await train_on_prime(
        config={
            "model_name": "meta-llama/Llama-3.1-8B-Instruct",
            "max_steps": 1000,
            "batch_size": 32,
        },
        gpu_type="H100_80GB",
        gpu_count=8,
    )

    if result.succeeded:
        print(f"Checkpoint: {result.checkpoint_path}")
train_ppo
train_ppo(config_dict: dict[str, Any], prompts: list[str], reward_fn: Callable[..., Any]) -> t.Any
Train with PPO.
train_sft
train_sft(config_dict: dict[str, Any], prompts: list[str]) -> t.Any
Train with SFT.
train_tinker_sft
train_tinker_sft(config: dict[str, Any] | None = None, messages: Sequence[list[dict[str, str]]] | None = None, examples: Sequence[tuple[str, str]] | None = None, data_dir: str | None = None, project: str | None = None, run_name: str | None = None, tags: list[str] | None = None, log_to_dreadnode: bool = True) -> TrainingState
Train a model using Tinker SFT.
This function provides a high-level interface for supervised fine-tuning using the Tinker framework. Data can be provided in multiple formats:
- Conversation messages (list of message dicts)
- Simple examples (input/output pairs)
- Parquet files in a data directory
Parameters:
config (dict[str, Any] | None, default: None) – Training configuration dict. See TinkerSFTConfig for options.
messages (Sequence[list[dict[str, str]]] | None, default: None) – List of conversations, each a list of message dicts with 'role' and 'content' keys.
examples (Sequence[tuple[str, str]] | None, default: None) – List of (input, output) tuples for simple supervised learning.
data_dir (str | None, default: None) – Directory containing parquet files with training data.
project (str | None, default: None) – Dreadnode project name.
run_name (str | None, default: None) – Dreadnode run name.
tags (list[str] | None, default: None) – Tags for the Dreadnode run.
log_to_dreadnode (bool, default: True) – Whether to log to Dreadnode (default: True).
Returns:
TrainingState – TrainingState with training metrics and checkpoint paths.
Raises:
ValueError – If no data source is provided.
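A hedged sketch using the examples= form with simple (input, output) pairs; config keys mirror TinkerSFTConfig. Whether the call must be awaited depends on the release, so check your version.

    from dreadnode.training import train_tinker_sft

    state = train_tinker_sft(
        config={"base_model": "meta-llama/Llama-3.1-8B-Instruct", "steps": 50, "lora_rank": 16},
        examples=[
            ("Translate to French: hello", "bonjour"),
            ("Translate to French: goodbye", "au revoir"),
        ],
        project="my-project",
        run_name="tinker-sft-demo",
    )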
verify_env_state
verify_env_state(env: TaskEnvironment, trajectory: Trajectory | None, verification: dict[str, Any] | None, *, judge_context: dict[str, Any] | None = None) -> VerificationResult
Grade the rollout against the task's verification config.
Supports three dispatch keys on the verification dict:
- env_flag – read a file from the env sandbox; compare against a sha256 hash (hash) or a plaintext expected value.
- env_script – execute a script inside the env; pass iff the exit code matches expected_exit_code (default 0).
- llm_judge – score trajectory with dreadnode.agents.AgentJudge against a rubric; pass iff the score clears passing_threshold.
Parameters:
env (TaskEnvironment) – A provisioned TaskEnvironment with execute() available.
trajectory (Trajectory | None) – The agent's rollout. Required for llm_judge; ignored by env_flag / env_script. Pass None for single-shot recipes that don't produce a trajectory.
verification (dict[str, Any] | None) – The task's verification config (typically from env.task_verification). None or a missing method raises ValueError.
judge_context (dict[str, Any] | None, default: None) – Optional context passed through to AgentJudge.evaluate when method=llm_judge. Good for the task instruction / env state.
Returns:
VerificationResult – The grading outcome.
Raises:
ValueError – if verification is missing, the method is unknown, or the chosen method's required fields are absent.
RuntimeError – if an env_flag / env_script invocation is attempted against an un-provisioned env (the caller must run setup() first).
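A hedged sketch of grading a finished rollout and turning it into a training reward. The env and trajectory objects come from other parts of the SDK; env.task_verification follows the parameter docs above, and the await is an assumption since the function drives env execution (drop it if your version is synchronous).

    result = await verify_env_state(
        env=env,                                  # a provisioned TaskEnvironment
        trajectory=trajectory,                    # required only for llm_judge
        verification=env.task_verification,
        judge_context={"instruction": "Capture the flag."},
    )

    reward = result.score if result.passed else 0.0   # score is in [0, 1]
    print(result.metrics.get("method"), result.metrics.get("exit_code"))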