强化学习环境设计:构建 LLM 对齐所需的模拟世界

FreeGuideOnline 最新 2026-06-29

“用户:请解释相对论。\n助手:”

智能体收到上述文本,需要决定如何延续这段对话。

### 1.2 动作(Action)

动作是智能体可以执行的操作。对于语言模型,**动作就是生成的下一个 token(或 token 序列)**。动作空间的大小等于词表大小,极其庞大,这也是 LLM 强化学习的难点之一。

### 1.3 奖励(Reward)

奖励是环境对智能体动作的即时评分。在传统 RL 中,奖励可能是游戏的得分;在 LLM 对齐中,**奖励来自一个经过训练的人类偏好模型**——它会对“有帮助、无害、诚实”的回答给出高分,对有害或不相关的回答给予低分。

### 1.4 转移函数(Transition Function)

转移函数描述了给定当前状态和动作,环境如何转移到下一个状态。在对话环境中,转移非常简单:**将动作(生成的文本)追加到当前历史之后,形成新的状态**。

数学表达:

S_{t+1} = S_t + "助手:" + A_t + "\n用户:" + 下一用户输入

注意:有些环境中“下一用户输入”也是由环境模拟生成的,而有些则从真实用户或固定数据集采样。

### 1.5 终止条件(Done)

一个 episode 何时结束?对话环境中常见终止条件:
- 生成了特定的结束标记(如 `<|endoftext|>`);
- 达到最大对话轮数;
- 奖励模型判定回答已完全无帮助,需要提前终止;
- 用户发出了停止信号。

---

## 2. 使用 OpenAI Gym 风格搭建第一个 LLM 环境

为了让环境能接入标准强化学习算法(如 PPO),我们需要遵循类似 Gym 的接口规范。下面将逐步实现一个简易的“文本问答环境”。

### 2.1 环境类的骨架

```python
import gym
from gym import spaces
import random

class TextQAGym(gym.Env):
    def __init__(self, prompts, max_turns=5):
        super().__init__()
        self.prompts = prompts            # 候选问题列表
        self.max_turns = max_turns
        self.action_space = spaces.Discrete(100)  # 占位,实际动作是 token 序列
        self.observation_space = spaces.Text(max_length=512)  # 文本状态
        self.state = None
        self.turn = 0

    def _get_obs(self):
        return self.state

    def reset(self):
        self.turn = 0
        self.state = random.choice(self.prompts) + "\n助手:"
        return self._get_obs()

    def step(self, action_text):
        # action_text 是模型生成的文本
        self.state += action_text
        reward = self._compute_reward(action_text)  # 稍后实现
        self.turn += 1

        # 判断是否终止
        done = ("[EOT]" in action_text) or (self.turn >= self.max_turns)
        # 给新状态追加下一轮用户输入(简化版:从数据集取下一问题)
        if not done:
            next_prompt = random.choice(self.prompts)
            self.state += "\n用户:" + next_prompt + "\n助手:"

        info = {}
        return self._get_obs(), reward, done, info

2.2 动作如何变成文本?

上面的 step() 接收的是已生成的文本 action_text。但在实际训练中,智能体(语言模型)会输出 logits 或采样后的 token ID。环境通常不直接处理 token ID,而是将解码后的文本作为输入。这层转换由智能体-环境交互层负责。

2.3 环境的观测与状态设计

观测即当前可用的信息。在多轮对话中,观测可以仅仅是当前回合的提示,也可以是完整历史。建议初学者使用完整历史作为状态,因为模型需要理解上下文。


3. 奖励设计:从启发式规则到人类偏好模型

奖励函数是环境设计中最关键的组件,它直接决定了对齐的方向。常见三种设计路径:

3.1 启发式规则奖励

早期实验或极简版本中,可以用规则快速起步,例如:

  • 回答长度过短(<5 tokens)给负奖励;
  • 包含预设的礼貌用语(如“谢谢”)给正奖励;
  • 检测到敏感词给极大负奖励。
FORBIDDEN = ["攻击性词汇1", "攻击性词汇2"]
def _compute_reward(self, text):
    if any(word in text for word in FORBIDDEN):
        return -1.0
    if len(text) < 10:
        return -0.5
    return 0.1  # 微小正奖励鼓励对话继续

但这些规则太生硬,难以覆盖对齐的全部需求。

3.2 训练一个奖励模型(Reward Model)

RLHF 的标准做法是:收集人类对回复的偏好数据(两个回复中哪个更好),训练一个奖励模型。该模型的输入是(提示,回复),输出一个标量奖励。

在环境中,奖励函数就变成了调用这个预训练模型:

class RewardModelEnv(TextQAGym):
    def __init__(self, reward_model, tokenizer, **kwargs):
        super().__init__(**kwargs)
        self.reward_model = reward_model
        self.tokenizer = tokenizer

    def _compute_reward(self, text):
        # 拼接提示和回复,送入奖励模型
        full_text = self.state + text   # self.state 在追加前是原始提示
        inputs = self.tokenizer(full_text, return_tensors="pt")
        with torch.no_grad():
            score = self.reward_model(**inputs).logits.item()
        return score

3.3 端到端的人类在环奖励

更精细的对齐会让人类直接对当前输出给出即时评分(或者事后标注)。这在环境设计中通过模拟延迟奖励来实现:每次 step() 时从预标注好的评分表读取奖励,或接入人类打分 API。


4. 构建一个对齐专用的模拟对话世界

现在我们将上面的零件组装成一个完整的环境,模拟用户和助手之间的对话对齐。

4.1 模拟用户行为

真实用户输入是多样且动态的,但训练环境往往需要从已有数据集抽样。可以把用户看作环境的另一部分:每次助手回答后,环境从数据集选取一个新的用户问题,或者根据对话主题自动生成一个追问。

数据集抽样版:

def _next_user_input(self):
    # 从对话数据集抽取下一个用户发言
    return self.dialog_dataset.sample()

更逼真的做法是引入一个“用户模拟器”小模型,根据上下文生成追问。但初学者可先使用静态数据集。

4.2 终止条件与 episode 边界

终止条件直接影响训练稳定性。典型设置:

  • 最大轮数终止:防止对话无限循环。
  • 奖励过低终止:如果奖励模型的分数低于某个阈值,认为助手已产生有害输出,立即结束该 episode。
  • EOS 标记终止:模型生成特殊的终止标记。

4.3 完整环境示例:对齐对话 Gym

下面展示如何把奖励模型、用户模拟和终止条件集成到一个环境类中:

class AlignmentDialogueEnv(gym.Env):
    def __init__(self, prompt_list, reward_model, tokenizer,
                 max_turns=3, stop_token="<|end|>"):
        self.prompt_list = prompt_list
        self.reward_model = reward_model
        self.tokenizer = tokenizer
        self.max_turns = max_turns
        self.stop_token = stop_token

        self.action_space = spaces.Text(max_length=200)
        self.observation_space = spaces.Text(max_length=1024)

        self.state = None
        self.turn = 0

    def reset(self):
        self.turn = 0
        prompt = random.choice(self.prompt_list)
        self.state = f"用户:{prompt}\n助手:"
        return self.state

    def step(self, action_text):
        # 1. 计算奖励
        reward = self._compute_reward(action_text)

        # 2. 追加助手回复
        self.state += action_text
        self.turn += 1

        # 3. 判断终止
        done = False
        if self.stop_token in action_text:
            done = True
        elif self.turn >= self.max_turns:
            done = True

        # 4. 准备下一状态
        if not done:
            next_user = random.choice(self.prompt_list)  # 简化用户输入
            self.state += f"\n用户:{next_user}\n助手:"

        info = {}
        return self.state, reward, done, info

    def _compute_reward(self, text):
        # 使用奖励模型评分
        prompt_end_idx = self.state.find("助手:") + 3
        prompt_text = self.state[:prompt_end_idx]
        model_input = prompt_text + text
        inputs = self.tokenizer(model_input, return_tensors="pt")
        with torch.no_grad():
            score = self.reward_model(**inputs).logits.item()
        return score

4.4 与 PPO 训练循环的集成伪代码

env = AlignmentDialogueEnv(prompts, reward_model, tokenizer)
obs = env.reset()
for update in range(ppo_updates):
    # 收集一批 rollout 数据
    batch_obs, batch_actions, batch_rewards, batch_dones = [], [], [], []
    for step in range(rollout_steps):
        action = agent.sample(obs)         # agent 是带策略头的 LLM
        next_obs, reward, done, _ = env.step(action)
        batch_obs.append(obs)
        batch_actions.append(action)
        batch_rewards.append(reward)
        batch_dones.append(done)

        obs = next_obs
        if done:
            obs = env.reset()
    # 用 PPO 更新智能体
    agent.update(batch_obs, batch_actions, batch_rewards, batch_dones)