Training does not demand much GPU memory, but it does demand a lot of time

Tokenizer

What it is

  • Analogous to how images represent natural scenes with pixels (RGB → color patches),
  • a tokenizer represents natural language with numeric values (IDs → language chunks)

How it's done

  • BPE
    • Count all pairs of adjacent symbols, merge the most frequent pair into a new unit, and keep iterating until the vocabulary reaches the target size (the final vocabulary consists of every base character plus each unit created by merging); see the sketch below
    • Within the vocabulary budget, common words become single units, while less common words survive only as fragments; e.g. unfamiliar may keep a piece such as un (if that piece is frequent enough)
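A minimal, self-contained sketch of the merge loop described above (the toy corpus, vocabulary size, and function name are made up for illustration; this is not the MiniMind tokenizer code):

python

from collections import Counter

def train_bpe(words, num_merges):
    # Each word starts as a sequence of single characters
    corpus = [list(w) for w in words]
    merges = []
    for _ in range(num_merges):
        # Count all adjacent symbol pairs across the corpus
        pairs = Counter()
        for symbols in corpus:
            for a, b in zip(symbols, symbols[1:]):
                pairs[(a, b)] += 1
        if not pairs:
            break
        # Merge the most frequent pair into one new unit
        best = max(pairs, key=pairs.get)
        merges.append(best)
        new_corpus = []
        for symbols in corpus:
            merged, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    merged.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            new_corpus.append(merged)
        corpus = new_corpus
    return merges, corpus

# "un" becomes a unit if it is frequent enough in the toy corpus
merges, corpus = train_bpe(["unfamiliar", "unhappy", "undo", "union"], num_merges=3)
print(merges)   # e.g. [('u', 'n'), ...]
print(corpus)   # 'unfamiliar' now contains the merged unit 'un'
Python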

Model

Overview

  • Compared with the original Transformer, the main structural changes are:
    • Norm: LayerNorm → RMSNorm
    • Attention: MHA → GQA
    • Positional Encoding: Sinusoidal Positional Encoding → RoPE
(figures: model architecture overview)

python

class MiniMindConfig(PretrainedConfig):
    model_type = "minimind"

    def __init__(
            self,
            dropout: float = 0.0,                  # Dropout probability (0 = no dropout)
            bos_token_id: int = 1,                 # <|im_start|> token ID
            eos_token_id: int = 2,                 # <|im_end|> token ID
            hidden_act: str = 'silu',              # Activation function (SiLU = Swish)
            hidden_size: int = 512,                # d_model - embedding dimension
            intermediate_size: int = None,         # FFN hidden dim (auto-computed if None)
            max_position_embeddings: int = 32768,  # Max sequence length
            num_attention_heads: int = 8,          # Number of query heads (Q)
            num_hidden_layers: int = 8,            # Number of transformer blocks
            num_key_value_heads: int = 2,          # Number of KV heads (for GQA)
            vocab_size: int = 6400,                # Vocabulary size
            rms_norm_eps: float = 1e-05,           # RMSNorm epsilon for stability
            rope_theta: float = 1000000.0,         # RoPE base frequency
            flash_attn: bool = True,               # Use Flash Attention if available
            # ... MoE configs omitted for clarity
    ):
        ...  # attribute assignments and super().__init__ call omitted
Python

RMSNorm

  • Set μ = 0 and substitute it into the LayerNorm formula
  • That means no re-centering, only rescaling, which reduces computation while still keeping training stable
  • LayerNorm: $y = \dfrac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta$, with mean $\mu$ and variance $\sigma^2$ taken over the feature dimension
  • RMSNorm: $y = \dfrac{x}{\mathrm{RMS}(x)} \cdot \gamma$, where $\mathrm{RMS}(x) = \sqrt{\dfrac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}$

python

class RMSNorm(torch.nn.Module):
    """
    RMSNorm is simpler and faster than LayerNorm.

    Instead of:   (x - mean) / std * gamma + beta
    RMSNorm does: x / RMS(x) * gamma

    No mean subtraction, no beta - just scale normalization.
    """
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps                               # Small constant for numerical stability
        self.weight = nn.Parameter(torch.ones(dim))  # Learnable scale (gamma)

    def _norm(self, x):
        # Compute RMS: sqrt(mean(x^2)); rsqrt = 1/sqrt for efficiency
        # Shape: [..., dim] → [..., dim] (unchanged)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Cast to float for precision, normalize, cast back, scale
        return self.weight * self._norm(x.float()).type_as(x)
Python

GQA

  • Split the query heads into several groups; each group shares a single K/V head pair
(figure: query heads grouped to share KV heads)
  • Additions compared with the original Transformer attention:
    • The grouped, shared K/V shows up mainly as smaller K/V projection layers
    • The KV cache stores keys/values for decoding (prefill: user prompt → first token; decode: autoregressive, previous tokens → next token). In prefill the whole user prompt is processed in one parallel pass, so seq_len is the tokenized prompt length; in decode seq_len is always 1. After K/V are computed and RoPE is applied, they are concatenated into the KV cache, so the cached sequence length grows by 1 each step
      • (figure: KV cache growth during prefill and decode)

python

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat KV heads to match the number of query heads.

    With GQA, we have fewer KV heads than Q heads.
    Example: 8 Q heads, 2 KV heads → each KV head serves 4 Q heads

    Args:
        x: [batch, seq_len, num_kv_heads, head_dim]
        n_rep: Repetition factor (num_q_heads // num_kv_heads)
    Returns:
        [batch, seq_len, num_q_heads, head_dim]
    """
    bs, slen, num_key_value_heads, head_dim = x.shape
    if n_rep == 1:
        # Standard MHA, no repetition needed
        return x
    # Expand and reshape: [b, s, kv, d] → [b, s, kv, rep, d] → [b, s, kv*rep, d]
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, num_key_value_heads, n_rep, head_dim)
        .reshape(bs, slen, num_key_value_heads * n_rep, head_dim)
    )


class Attention(nn.Module):
    def __init__(self, args: MiniMindConfig):
        super().__init__()
        # GQA setup: 8 query heads, 2 KV heads
        self.num_key_value_heads = args.num_key_value_heads       # 2
        self.n_local_heads = args.num_attention_heads             # 8
        self.n_local_kv_heads = self.num_key_value_heads          # 2
        self.n_rep = self.n_local_heads // self.n_local_kv_heads  # 8/2 = 4
        # Head dimension: 512 / 8 = 64 per head
        self.head_dim = args.hidden_size // args.num_attention_heads  # 64
        # Projections
        # Q: [512] → [8 * 64] = [512]
        self.q_proj = nn.Linear(args.hidden_size, args.num_attention_heads * self.head_dim, bias=False)
        # K: [512] → [2 * 64] = [128] (fewer because GQA)
        self.k_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        # V: [512] → [2 * 64] = [128]
        self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        # O: [512] → [512]
        self.o_proj = nn.Linear(args.num_attention_heads * self.head_dim, args.hidden_size, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn

    def forward(self, x, position_embeddings, past_key_value=None, use_cache=False, attention_mask=None):
        bsz, seq_len, _ = x.shape  # [batch, seq, 512]

        # Project to Q, K, V
        xq = self.q_proj(x)  # [batch, seq, 512]
        xk = self.k_proj(x)  # [batch, seq, 128]
        xv = self.v_proj(x)  # [batch, seq, 128]

        # Reshape to multi-head format
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)     # [b, s, 8, 64]
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)  # [b, s, 2, 64]
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)  # [b, s, 2, 64]

        # Apply RoPE
        cos, sin = position_embeddings
        xq, xk = apply_rotary_pos_emb(xq, xk, cos[:seq_len], sin[:seq_len])

        # KV Cache for autoregressive generation
        if past_key_value is not None:
            xk = torch.cat([past_key_value[0], xk], dim=1)  # Append new KV
            xv = torch.cat([past_key_value[1], xv], dim=1)
        past_kv = (xk, xv) if use_cache else None

        # Prepare for attention: transpose to [batch, heads, seq, dim]
        xq = xq.transpose(1, 2)                         # [b, 8, s, 64]
        xk = repeat_kv(xk, self.n_rep).transpose(1, 2)  # [b, 8, s, 64] (repeated!)
        xv = repeat_kv(xv, self.n_rep).transpose(1, 2)  # [b, 8, s, 64]

        # Compute attention
        if self.flash and seq_len > 1:
            # Flash Attention: O(n) memory, faster
            output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
        else:
            # Standard attention: Q @ K^T / sqrt(d) → softmax → @ V
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)  # [b, 8, seq_q, seq_k]
            # Causal mask: prevent attending to future tokens
            scores = scores + torch.triu(
                torch.full((seq_len, seq_len), float("-inf"), device=scores.device),
                diagonal=1
            ).unsqueeze(0).unsqueeze(0)
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv  # [b, 8, s, 64]

        # Reshape back and project
        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)  # [b, s, 512]
        output = self.resid_dropout(self.o_proj(output))
        return output, past_kv
Python

RoPE

  • Encodes relative position rather than absolute position → good extrapolation (trained on short contexts such as 512 tokens, it still performs reasonably at 768, 1024 and beyond)
  • Positional encoding expresses position information; the relative position m − n between the embedding at position m and the one at position n can be modeled by a rotation in space, derived as follows:
(figure: RoPE rotation derivation)
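Since the derivation figure is not reproduced here, the following is a brief sketch of the standard RoPE identity it builds on: each 2-dimensional pair of features at position m is rotated by the angle mθ,

$R_m = \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix}, \qquad f(x, m) = R_m x$

and the attention dot product between a rotated query and a rotated key then depends only on the offset,

$\langle f(q, m),\, f(k, n) \rangle = q^\top R_m^\top R_n\, k = q^\top R_{n-m}\, k$

i.e. a function of m − n alone, which is exactly why RoPE extrapolates better than absolute positional encodings.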

python

def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6,
                         rope_scaling: Optional[dict] = None):
    """
    Precompute sin/cos tables for RoPE.

    RoPE encodes position by ROTATING the query/key vectors.
    Position i rotates vectors by angle i*θ, where θ depends on dimension.

    Args:
        dim: Dimension per head (hidden_size // num_attention_heads = 512//8 = 64)
        end: Maximum sequence length to precompute
        rope_base: Base for frequency computation (higher = longer context)
    """
    # Compute frequencies for each dimension pair
    # freqs[i] = 1 / (base^(2i/dim)) for i = 0, 1, 2, ... dim/2
    freqs = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # Shape: [dim/2] = [32] for dim=64

    # For each position t, compute t * freq
    t = torch.arange(end)                  # [0, 1, 2, ..., 32767]
    freqs = torch.outer(t, freqs).float()  # [32768, 32] - angles for each position

    # Precompute cos and sin (doubled for complex rotation)
    freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)  # [32768, 64]
    freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1)  # [32768, 64]
    return freqs_cos, freqs_sin


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """
    Apply rotary position embeddings to queries and keys.

    The rotation formula: q_rotated = q * cos + rotate_half(q) * sin
    This encodes relative position: dot(q_i, k_j) depends on (i-j).
    """
    def rotate_half(x):
        # Split x into two halves and swap with sign change
        # [a, b, c, d, e, f] → [-d, -e, -f, a, b, c]
        return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)

    # Apply rotation: q' = q*cos + rotate(q)*sin
    q_embed = (q * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(q) * sin.unsqueeze(unsqueeze_dim))
    k_embed = (k * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(k) * sin.unsqueeze(unsqueeze_dim))
    return q_embed, k_embed
Python
 
 

SwiGLU FFN

  • Adds a gate: when expanding to the intermediate dimension, two linear projections are produced instead of one; one is the usual up-projection, the other passes through a SiLU activation, and the two are multiplied element-wise to realize gated filtering
$\mathrm{FFN}_{\mathrm{SwiGLU}}(x) = W_{\mathrm{down}}\big(\mathrm{SiLU}(W_{\mathrm{gate}}\,x) \odot W_{\mathrm{up}}\,x\big)$

python

class FeedForward(nn.Module):
    """
    SwiGLU FFN: gate * up, then down projection.

    Standard FFN: x → Linear → ReLU → Linear → out
    SwiGLU FFN:   x → [SiLU(gate(x)) * up(x)] → down → out

    The gating mechanism improves gradient flow and performance.
    """
    def __init__(self, config: MiniMindConfig):
        super().__init__()
        # Compute intermediate size (if not specified)
        if config.intermediate_size is None:
            # Rule of thumb: 8/3 * hidden_size, rounded to multiple of 64
            intermediate_size = int(config.hidden_size * 8 / 3)                   # 512 * 8/3 ≈ 1365
            config.intermediate_size = 64 * ((intermediate_size + 64 - 1) // 64)  # → 1408
        # Three projections (SwiGLU uses gate + up instead of just one)
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)  # [512] → [1408]
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)    # [512] → [1408]
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)  # [1408] → [512]
        self.dropout = nn.Dropout(config.dropout)
        self.act_fn = ACT2FN[config.hidden_act]  # SiLU activation

    def forward(self, x):
        # SwiGLU: SiLU(gate(x)) * up(x), then down project
        # x: [batch, seq, 512]
        gate = self.act_fn(self.gate_proj(x))  # [batch, seq, 1408]
        up = self.up_proj(x)                   # [batch, seq, 1408]
        hidden = gate * up                     # Element-wise gating
        output = self.down_proj(hidden)        # [batch, seq, 512]
        return self.dropout(output)
Python

Transformer Block

python

class MiniMindBlock(nn.Module):
    """
    One transformer block = Attention + FFN with residual connections.

    Structure:
        x → RMSNorm → Attention → + x (residual)
          → RMSNorm → FFN → + (residual)

    This is "Pre-Norm" architecture (normalize before, not after).
    """
    def __init__(self, layer_id: int, config: MiniMindConfig):
        super().__init__()
        self.layer_id = layer_id
        # Attention sublayer
        self.self_attn = Attention(config)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # FFN sublayer (could be regular FFN or MoE)
        self.mlp = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states, position_embeddings, past_key_value=None, use_cache=False, attention_mask=None):
        # hidden_states: [batch, seq, 512]

        # --------------- Attention Sublayer ---------------
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)  # Pre-norm
        hidden_states, present_key_value = self.self_attn(
            hidden_states, position_embeddings, past_key_value, use_cache, attention_mask
        )
        hidden_states = hidden_states + residual  # Residual connection

        # --------------- FFN Sublayer ---------------
        # Note: We add to hidden_states, not to a new residual
        hidden_states = hidden_states + self.mlp(
            self.post_attention_layernorm(hidden_states)  # Pre-norm
        )
        return hidden_states, present_key_value
Python

MiniMind

  • If there is no KV cache yet, the step is prefill: the start position is 0, the input is the whole prompt, and the K/V for every prompt token is computed in one parallel pass and written into the cache. Once the cache exists, its length equals the tokenized prompt length, and that length becomes the start index of the next step. From then on the input is a single token (first the token produced by prefill, afterwards the token generated in the previous step): its own K/V is computed and appended to the cache, and its query attends over the whole cache, so feeding just this one token covers all K/V from the beginning of the user prompt up to the current token (a minimal generation loop illustrating this follows the code below)

python

class MiniMindModel(nn.Module):
    """
    Stack of transformer blocks with embedding and final norm.

    Flow: token_ids → embed → [block × 8] → norm → hidden_states
    """
    def __init__(self, config: MiniMindConfig):
        super().__init__()
        self.config = config
        # Token embedding: 6400 tokens → 512 dimensions
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)
        # Stack of transformer blocks
        self.layers = nn.ModuleList([
            MiniMindBlock(layer_id, config)
            for layer_id in range(config.num_hidden_layers)  # 8 blocks
        ])
        # Final normalization
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Precompute RoPE frequencies
        freqs_cos, freqs_sin = precompute_freqs_cis(
            dim=config.hidden_size // config.num_attention_heads,  # 64
            end=config.max_position_embeddings,                    # 32768
            rope_base=config.rope_theta                            # 1e6
        )
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=False):
        batch_size, seq_length = input_ids.shape  # [batch, seq]

        # Embed tokens
        hidden_states = self.embed_tokens(input_ids)  # [batch, seq, 512]
        hidden_states = self.dropout(hidden_states)

        # Get position embeddings for this sequence
        start_pos = past_key_values[0][0].shape[1] if past_key_values and past_key_values[0] else 0
        position_embeddings = (
            self.freqs_cos[start_pos:start_pos + seq_length],
            self.freqs_sin[start_pos:start_pos + seq_length]
        )

        # Pass through all transformer blocks
        presents = []
        past_key_values = past_key_values or [None] * len(self.layers)
        for layer, past_kv in zip(self.layers, past_key_values):
            hidden_states, present = layer(
                hidden_states, position_embeddings,
                past_key_value=past_kv, use_cache=use_cache, attention_mask=attention_mask
            )
            presents.append(present)

        # Final normalization
        hidden_states = self.norm(hidden_states)  # [batch, seq, 512]

        # MoE auxiliary load-balancing loss; 0 here because the MoE path is omitted in this excerpt
        aux_loss = 0
        return hidden_states, presents, aux_loss


class MiniMindForCausalLM(PreTrainedModel, GenerationMixin):
    """
    Complete causal language model with output head.

    Flow: input_ids → MiniMindModel → lm_head → logits
    """
    def __init__(self, config: MiniMindConfig = None):
        self.config = config or MiniMindConfig()
        super().__init__(self.config)
        # The transformer
        self.model = MiniMindModel(self.config)
        # Output projection: [512] → [6400] (vocabulary)
        self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
        # WEIGHT TYING: embed_tokens and lm_head share the same weight matrix!
        # This saves parameters and forces consistency between input/output representations
        self.model.embed_tokens.weight = self.lm_head.weight

    def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=False, logits_to_keep=0):
        # Get hidden states from transformer
        hidden_states, past_key_values, aux_loss = self.model(
            input_ids, attention_mask, past_key_values, use_cache
        )
        # hidden_states: [batch, seq, 512]

        # Project to vocabulary
        logits = self.lm_head(hidden_states)  # [batch, seq, 6400]
        return CausalLMOutputWithPast(
            logits=logits,
            past_key_values=past_key_values,
            hidden_states=hidden_states
        )
Python
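To make the prefill/decode split concrete, here is a minimal greedy-decoding loop (a sketch, not code from the repository; it assumes a model and tokenizer wired up as in the classes above, using the use_cache / past_key_values interface shown there):

python

import torch

@torch.no_grad()
def generate(model, tokenizer, prompt, max_new_tokens=32):
    # ---- Prefill: process the whole prompt in one pass (start_pos = 0) ----
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids  # [1, prompt_len]
    out = model(input_ids, use_cache=True)                        # caches K/V for every prompt token
    past_key_values = out.past_key_values
    next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)   # first generated token
    generated = [next_token]

    # ---- Decode: one token per step, its K/V appended to the cache ----
    for _ in range(max_new_tokens - 1):
        out = model(next_token, past_key_values=past_key_values, use_cache=True)  # seq_len == 1
        past_key_values = out.past_key_values                     # cached length grows by 1
        next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)
        generated.append(next_token)
        if next_token.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(torch.cat(generated, dim=-1)[0])
Python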
 

Embedding

  • Role: ID → vector
    • It maps tokenizer output IDs to vectors of size hidden_size; the values start out random and are adjusted during training

javascript

Text: "你好" (Tokenizer) Token IDs: [234, 567] (example IDs) (Embedding Layer) Vectors: [[0.12, -0.34, ..., 0.56], # 512-dim vector for token 234 [0.78, 0.91, ..., -0.23]] # 512-dim vector for token 567Shape: [batch, seq_len, hidden_size] → e.g., [1, 2, 512]
JavaScript
  • The input embedding and the output projection share the same weight matrix, saving parameters (see the sketch below)
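A tiny illustration of that weight tying (plain PyTorch, mirroring the MiniMindForCausalLM constructor above rather than adding anything new):

python

import torch.nn as nn

embed = nn.Embedding(6400, 512)             # input side: ID → 512-dim vector
lm_head = nn.Linear(512, 6400, bias=False)  # output side: 512-dim vector → vocabulary logits
lm_head.weight = embed.weight               # both now point at one shared [6400, 512] matrix
# Sharing saves roughly 6400 * 512 ≈ 3.3M parameters versus keeping two separate matrices
Python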

Pretrain

What it is

  • The task is to predict the next token from the preceding tokens (causal language modeling)

javascript

Input: "The cat sat on the" Target: "mat" The model learns: P(next_token | previous_tokens)
JavaScript
  • Learns general language understanding (learning objective → example):
    • Grammar: "The cat sits" not "The cat sit"
    • Facts: "Paris is the capital of France"
    • Patterns: code syntax, mathematical notation
    • Multilingual: Chinese characters, English words
    • Context: long-range dependencies

How it's done

Data

  • The dataset is fairly large; a quick head command on the file is enough to see the data pattern
(figure: head output of the pretrain JSONL file)
  • The content is pure text: each JSON object in the JSONL file has only a text field whose value is plain sentences, and each sample begins and ends with special tokens
 
(figure: sample pretrain text with special tokens)
  • Special-token configuration
    • (figure: special-token settings in the tokenizer configuration)
  • Related code
    • Read the text, truncate it if too long, pad it if too short, and encode it with the tokenizer; define a loss_mask so that padded positions contribute no loss. Because the task is next-token prediction, the first token is never a label and the last token is never an input, so the first position is also excluded from the loss

python

class PretrainDataset(Dataset):
    def __getitem__(self, index):
        sample = self.samples[index]

        # Tokenize the raw text
        encoding = self.tokenizer(
            str(sample['text']),
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoding.input_ids.squeeze()

        # Loss mask: compute loss on ALL tokens (except padding)
        loss_mask = (input_ids != self.tokenizer.pad_token_id)

        # Shift for next-token prediction
        X = torch.tensor(input_ids[:-1], dtype=torch.long)         # Input tokens
        Y = torch.tensor(input_ids[1:], dtype=torch.long)          # Target tokens (shifted by 1)
        loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long)  # Mask aligned with targets
        return X, Y, loss_mask
Python

Training

  • loss_mask masks out the padded positions
  • Cross-entropy loss is used

python

def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
    loss_fct = nn.CrossEntropyLoss(reduction='none')
    for step, (X, Y, loss_mask) in enumerate(loader):
        X = X.to(args.device)                  # [batch, seq_len-1]
        Y = Y.to(args.device)                  # [batch, seq_len-1] (shifted targets)
        loss_mask = loss_mask.to(args.device)

        # Forward pass
        res = model(X)

        # Compute loss at each position
        loss = loss_fct(
            res.logits.view(-1, res.logits.size(-1)),  # [batch*seq, vocab_size]
            Y.view(-1)                                 # [batch*seq]
        ).view(Y.size())

        # Apply mask and average
        loss = (loss * loss_mask).sum() / loss_mask.sum()

        # Backward and optimize
        scaler.scale(loss).backward()
Python

Results

  • Single RTX 4090, using the MiniMind2 104M-parameter configuration
  • ~4 h for one epoch
(figures: pretraining loss curves)
  • Matches the result curves shown in the repository
(figure: reference training curve from the repository)

SFT/IFT

What it is

  • Supervised Fine-Tuning (SFT) / Instruction Fine-Tuning
  • Adapts the pretrained model to instruction following and dialogue, instead of mere text continuation

How it's done

Data

  • Each JSON object has a conversations field holding a list of turns; each turn consists of content and a role, and the loss is later computed only on turns whose role is assistant
(figure: sample SFT conversation data)
  • The assistant span is delimited by start and end token-ID sequences

python

class SFTDataset(Dataset):
    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(jsonl_path)
        # Markers for finding assistant responses
        self.bos_id = tokenizer(f'{tokenizer.bos_token}assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer(f'{tokenizer.eos_token}', add_special_tokens=False).input_ids
Python
  • Search for the start and end IDs defined above for the assistant span, set the mask to 1 for the content in between, and thereby restrict the loss computation to that span only

python

def _generate_loss_mask(self, input_ids):
    """Only compute loss on ASSISTANT responses"""
    loss_mask = [0] * len(input_ids)  # Start with all zeros
    i = 0
    while i < len(input_ids):
        # Find "<|im_start|>assistant"
        if input_ids[i:i + len(self.bos_id)] == self.bos_id:
            start = i + len(self.bos_id)
            end = start
            # Find the corresponding "<|im_end|>"
            while end < len(input_ids):
                if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                    break
                end += 1
            # Set mask to 1 ONLY for assistant response tokens
            for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
                loss_mask[j] = 1
            i = end + len(self.eos_id)
        else:
            i += 1
    return loss_mask
Python

python

Token: <|im_start|> system \n You... <|im_end|> <|im_start|> user \n What... <|im_end|> <|im_start|> assistant \n The answer is 4 . <|im_end|>
Mask:       0         0    0    0        0           0        0   0    0        0           0          0      0   1    1    1  1 1     1
                                                                                             ↑ LOSS COMPUTED ONLY HERE ↑
Python

Training

  • The data format and the resulting loss mask are the only differences between pretrain and SFT
  • Comparing train_pretrain.py with train_full_sft.py, the training logic is exactly the same
(figure: diff of train_pretrain.py vs train_full_sft.py)

Results

  • Training was not fully continuous, but it did take a long time, probably no less than two days in total
(figures: SFT training run screenshots)
  • Hard to evaluate; the curves all look oscillatory
(figures: SFT loss curves)

RL

  • P → Policy: the LLM being trained plays that role in the RL setting; O → Optimization
  • Preference / Reasoning Fine-Tuning
  • DPO | RLHF (Human Feedback) & RLVR (Verifiable Rewards) (PPO → GRPO → SPO)

DPO

  • Contrastive in spirit: the dataset must provide Chosen/Rejected pairs
  • The loss is computed from the difference against a frozen reference (baseline) model
  • Maximizes the probability gap between chosen and rejected, but only to a bounded extent
$\mathcal{L}_{\mathrm{DPO}} = -\,\mathbb{E}_{(x,\,y_w,\,y_l)}\Big[\log \sigma\Big(\beta \log \dfrac{\pi_\theta(y_w\mid x)}{\pi_{\mathrm{ref}}(y_w\mid x)} - \beta \log \dfrac{\pi_\theta(y_l\mid x)}{\pi_{\mathrm{ref}}(y_l\mid x)}\Big)\Big]$
  • Models preference as a probability
  • log → turns products of probabilities into sums
  • the minus sign → turns maximization into minimization
  • β constrains how far the policy may drift from the reference model
  • The code rearranges the order of these operations (see the derivation sketch below)
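A compact sketch of where the loss above comes from (the standard DPO derivation, stated here for completeness rather than taken from the original figure): the Bradley-Terry model turns a pairwise preference into a probability,

$p(y_w \succ y_l \mid x) = \sigma\big(r(x, y_w) - r(x, y_l)\big)$

DPO substitutes the implicit reward $r(x, y) = \beta \log \dfrac{\pi_\theta(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}$ and minimizes the negative log-likelihood of the observed preferences, which yields exactly the $-\log\sigma(\cdot)$ loss above; β scales how strongly deviation from $\pi_{\mathrm{ref}}$ is rewarded or penalized.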

python

import torch
import torch.nn.functional as F


def dpo_loss(policy_logprobs, ref_logprobs, chosen_indices, rejected_indices, beta=0.1):
    """
    policy_logprobs: Log probabilities from the model being trained (Batch, Seq_Len)
    ref_logprobs:    Log probabilities from the frozen reference model (Batch, Seq_Len)
    """
    # Gather log probs for chosen and rejected responses
    # (Simplified gathering logic for illustration)
    pi_logr_chosen = gather_logprobs(policy_logprobs, chosen_indices)
    pi_logr_rejected = gather_logprobs(policy_logprobs, rejected_indices)
    ref_logr_chosen = gather_logprobs(ref_logprobs, chosen_indices)
    ref_logr_rejected = gather_logprobs(ref_logprobs, rejected_indices)

    # Calculate log ratios
    pi_logr_ratio = pi_logr_chosen - pi_logr_rejected
    ref_logr_ratio = ref_logr_chosen - ref_logr_rejected

    # The DPO implicit reward difference
    logits = pi_logr_ratio - ref_logr_ratio

    # Binary Cross Entropy Loss (sigmoid internally)
    loss = -F.logsigmoid(beta * logits).mean()
    return loss
Python

GRPO

  • Increase each answer's advantage relative to the other answers in its group, while limiting the drift from the original (reference) model
Group-relative advantage: $\hat{A}_i = \dfrac{r_i - \operatorname{mean}(r_1,\dots,r_G)}{\operatorname{std}(r_1,\dots,r_G)}$
Clipped objective with KL penalty: $\mathcal{J}_{\mathrm{GRPO}}(\theta) = \mathbb{E}\Big[\dfrac{1}{G}\sum_{i=1}^{G}\min\big(\rho_i\hat{A}_i,\ \operatorname{clip}(\rho_i,\,1-\epsilon,\,1+\epsilon)\,\hat{A}_i\big)\Big] - \beta\, D_{\mathrm{KL}}\!\big(\pi_\theta\,\|\,\pi_{\mathrm{ref}}\big)$, where $\rho_i = \dfrac{\pi_\theta(o_i\mid q)}{\pi_{\theta_{\mathrm{old}}}(o_i\mid q)}$

python

import torch
import torch.nn.functional as F


def grpo_loss(
    current_logprobs,  # Log probs of the model we are training
    old_logprobs,      # Log probs from when we generated the samples (fixed)
    ref_logprobs,      # Log probs from the reference model (frozen)
    rewards,           # Raw scores [Batch, Group_Size]
    beta=0.04,         # KL penalty weight
    clip_eps=0.2       # PPO clipping parameter
):
    """
    Assume inputs are shaped (Batch, Group_Size, Seq_Len)
    """
    # --- 1. Calculate Group Advantages ---
    # We compute statistics along dim=1 (the group dimension)
    mean_rewards = rewards.mean(dim=1, keepdim=True)
    std_rewards = rewards.std(dim=1, keepdim=True)
    # Standardize (Z-Score): "Grading on a curve"
    advantages = (rewards - mean_rewards) / (std_rewards + 1e-4)

    # --- 2. Calculate PPO-style Ratio ---
    # ratio = exp( log_current - log_old )
    # Note: Usually we sum logprobs over the sequence length first
    token_ratio = (current_logprobs - old_logprobs).exp()

    # --- 3. Compute Surrogate Loss (Clipped) ---
    # Unclipped part: Ratio * Advantage
    surr1 = token_ratio * advantages.unsqueeze(-1)
    # Clipped part: Clamp(Ratio) * Advantage
    surr2 = torch.clamp(token_ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages.unsqueeze(-1)
    # PPO Objective: Take the min (pessimistic bound)
    policy_loss = -torch.min(surr1, surr2).mean()

    # --- 4. KL Divergence Penalty ---
    # Approximate KL: log_current - log_ref
    # We want to minimize KL, so we add it to the loss
    kl_div = (current_logprobs - ref_logprobs)  # This is actually a log-ratio, approximates KL
    kl_loss = beta * kl_div.mean()

    # Final GRPO Loss
    total_loss = policy_loss + kl_loss
    return total_loss
Python

SPO

PPO

 

Reason

LoRA

Distillation

 
θ angle

References and further reading

 