P-Tuning:基于可学习提示的连续空间微调

FreeGuideOnline 最新 2026-06-22

P-Tuning:让大模型听懂你的“连续密码”

P-Tuning 是一种参数高效的大语言模型微调方法。它不再依赖人工设计离散的提示词,而是在模型的输入层插入一组可学习的连续向量,通过反向传播自动寻找最能激发模型能力的“虚拟提示”。这种方法大幅降低了使用大模型的门槛,只需调整极少参数,就能让同一模型出色地完成分类、生成、推理等不同任务。


为什么需要 P-Tuning:从离散提示到连续空间

传统提示工程的困境

在 GPT-3 等大模型兴起后,提示工程成为常用的任务适配手段。但手工编写提示词存在明显局限:

  • 表达瓶颈:人类语言的离散词汇不一定能精确映射模型内部的知识结构。
  • 稳定性差:同一含义换一种表述,模型表现可能剧烈波动。
  • 搜索困难:提示词的组合空间巨大,人工试错成本极高。
  • 长尾任务:对于冷门领域或复杂推理,很难靠直觉设计出高效提示。

自动化提示的演进

研究者开始探索让模型自己学习“提示”。早期方法如 AutoPrompt 在离散词汇中搜索最优提示,但搜索过程昂贵且仍受限于离散解空间。P-Tuning 的创新在于将提示从离散 token 空间解放到连续 embedding 空间,让模型能以任意实数向量作为提示,从而摆脱人类语言的束缚。


核心思想:把提示变成可训练的参数

P-Tuning 将“提示”视作一段可微分的连续向量序列,直接拼接在真实输入的 embedding 前面,整个流程如下:

  1. 固定大模型参数:原始预训练语言模型(如 BERT、GPT)的全部权重冻结,不参与梯度更新。
  2. 插入可学习的 Prompt Embeddings:在输入层引入若干可训练的向量,称为“虚拟 token”。这些向量的维度与模型词嵌入完全相同。
  3. 使用 LSTM 或 MLP 进行编码(可选):为了增强虚拟 token 之间的交互,P-Tuning 引入一个轻量的双向 LSTM 或小型 MLP 网络作为提示编码器,将独立的 prompt 向量映射为上下文相关的表示。
  4. 与真实输入拼接:编码后的 prompt embeddings 拼接在真实输入序列 embedding 的前面,形成完整的输入序列,送入冻结的模型主干。
  5. 端到端微调提示参数:仅优化这些新增的 prompt embeddings 和提示编码器参数(若使用),通过下游任务的损失函数反向传播,让模型自己学会最适合当前任务的“连续提示”。

与传统微调的对比

对比维度 全参数微调 P-Tuning
可训练参数量 全部模型参数(数十亿) 仅有 prompt 向量和编码器(百万级)
存储成本 每个任务需保存一份完整模型 每个任务只需保存新增的 prompt 字典
训练速度 慢,需大量 GPU 资源 快,低资源设备也可训练
灾难性遗忘 容易丢失通用能力 主干冻结,完美保留预训练知识
可解释性 参数变化难以解读 可将连续提示映射回最近邻词汇,获得近似解读

模型架构拆解

P-Tuning 的结构可以拆为三个关键组件:

1. 连续提示嵌入层

假设模型原来的词表大小为 ( V ),嵌入维度为 ( d )。P-Tuning 会初始化一个形状为 ((p, d)) 的可学习矩阵 ( P ),其中 ( p ) 是虚拟 token 的数量(通常取 5~20)。这些向量随机初始化,类似新增了 ( p ) 个“单词”。

2. 提示编码器

直接将独立的 prompt 向量拼接到输入序列,会导致每个位置之间缺乏依赖关系。为此,P-Tuning 使用一个轻量级的双向 LSTM 或两层 MLP 作为提示编码器 ( f_{enc} ),将 ( P ) 映射为 ( P' = f_{enc}(P) )。编码器的输出维度仍为 ( d )。

  • 为什么需要编码器:直接让每个虚拟 token 独立学习,容易陷入次优解。编码器引入了 token 间的位置交互,使学习更稳定。
  • 编码器只在训练时存在:推理阶段,可以将编码器的输出缓存为静态向量,无需重复运行,消除额外推理延迟。

3. 输入拼接与模型前向

给定原始输入序列 ( X = [x_1, x_2, ..., x_n] ),其 embedding 为 ( [e(x_1), ..., e(x_n)] )。构造新的输入为:

[P'_1, P'_2, ..., P'_p, e(x_1), e(x_2), ..., e(x_n)]

将此序列输入冻结的主干模型,获取输出表示。对于分类任务,通常取第一个 token 的最终隐状态送入线性分类头;对于生成任务,直接参与自回归解码。


P-Tuning 实战:从原理到代码

以下使用 Hugging Face Transformers 的 p-tuning 风格流程演示关键步骤,假设使用 BERT 进行文本分类。

环境与模型准备

import torch
from transformers import BertModel, BertTokenizer

model_name = "bert-base-uncased"
model = BertModel.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)

# 冻结主模型所有参数
for param in model.parameters():
    param.requires_grad = False

定义 P-Tuning 模块

class PromptEncoder(torch.nn.Module):
    def __init__(self, prompt_length, hidden_size, lstm_hidden=128):
        super().__init__()
        self.lstm = torch.nn.LSTM(
            input_size=hidden_size,
            hidden_size=lstm_hidden,
            num_layers=1,
            bidirectional=True,
            batch_first=True
        )
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(2 * lstm_hidden, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, hidden_size)
        )

    def forward(self, prompt_embeddings):
        # prompt_embeddings: (batch, prompt_length, hidden_size)
        out, _ = self.lstm(prompt_embeddings)
        return self.mlp(out)

组合完整模型

class PtuningBertForClassification(torch.nn.Module):
    def __init__(self, bert, num_labels, prompt_length=10):
        super().__init__()
        self.bert = bert
        self.config = bert.config
        hidden_size = self.config.hidden_size
        # 可学习的 prompt embeddings
        self.prompt_embeddings = torch.nn.Parameter(
            torch.randn(prompt_length, hidden_size)
        )
        self.prompt_encoder = PromptEncoder(prompt_length, hidden_size)
        self.classifier = torch.nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        batch_size = input_ids.size(0)
        # 真实 token 的词嵌入
        inputs_embeds = self.bert.embeddings.word_embeddings(input_ids) # 省略 position embedding 等处理,实际需完整获取
        # 扩展 prompt 至 batch 维度
        prompt = self.prompt_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
        prompt = self.prompt_encoder(prompt)
        # 拼接 prompt + 输入
        inputs_embeds = torch.cat([prompt, inputs_embeds], dim=1)
        # 更新 attention_mask,为 prompt 部分添加全 1
        prompt_mask = torch.ones(batch_size, prompt_length, dtype=attention_mask.dtype, device=input_ids.device)
        attention_mask = torch.cat([prompt_mask, attention_mask], dim=1)
        # 通过冻结的 BERT
        outputs = self.bert(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        # 取 [CLS] 位置(prompt 后的第一个真实 token)或 prompt 后第一个 token 用于分类
        cls_output = outputs.last_hidden_state[:, prompt_length, :]
        logits = self.classifier(cls_output)
        return logits

训练与优化

  • 优化器:仅优化 prompt_embeddin] 不过需声明可训练部分,上面代码中 param.requires_grad 已设 Falseprompt_embeddingsprompt_encoderclassifier`。
  • 学习率:典型学习率设为 1e-3 至 5e-5,比全参数微调稍高。
  • 提示长度:一般 5~20,任务复杂时可适当增加。
  • 验证策略:可以使用多个随机种子初始化 prompt 权重,选择验证集表现最佳的检查点。

进阶理解:连续提示背后的秘密

为什么连续向量能替代离散提示?

可以将大模型的输入空间看作一个高维流形,离散词嵌入只是流形上的若干锚点。连续提示向量可以在流形的任意位置移动,从而表达出词汇表无法精确描述的概念组合。训练过程将这个向量拉伸到最能激活任务相关知识的区域。

编码器的作用与取舍

  • LSTM vs MLP:LSTM 对长提示序列建模能力更强,MLP 效率更高。实验表明,当提示长度较短时,MLP 效果良好。
  • 推理阶段移除编码器:训练完成后,将编码器的输出缓存为固定 prompt 数组。输入时直接读取静态向量,完全去除了编码器的计算开销,使推理速度与原始模型一致。

P-Tuning 的变体与演进

  • P-Tuning v2:将连续提示插入到模型的每一层,而不仅是输入层,进一步提升了小模型上的表现,并弥合了与全参数微调的差距。
  • Prefix Tuning:与 P-Tuning 同期提出,为自回归模型(如 GPT)的每层添加可学习的 prefix 向量,结构和用途高度相似,均属于“提示微调”范式。
  • LoRA:采用低秩矩阵更新模型权重,属于另一类参数高效微调,可与 P-Tuning 互补使用。

优势与局限全景分析

核心优势

  • 极致的参数效率:百万级参数即可适配新任务,数百个任务仅需数十 MB 存储空间。
  • 完美保留通用知识:模型主体不变,杜绝灾难性遗忘,可放心在多个下游任务间切换。
  • 训练资源大幅降低:单卡消费级 GPU 即可快速完成训练,显存占用仅为全参数微调的几十分之一。
  • 支持多任务部署:只需动态切换 prompt 文件,同一份模型权重就能应对不同场景。

当前局限

  • 小模型效果有限:对于参数量小于 1B 的模型,P-Tuning 的表现仍与全参数微调有差距(P-Tuning v2 部分缓解)。
  • 超参敏感:提示长度、学习率、编码器结构对最终效果影响较大,需要少量调参。
  • 解释性弱于离散提示:虽然可以将连续向量映射回最近邻词汇,但语义往往模糊不清,难以直接用于调试。
  • 生成任务需特殊设计:对于长文本生成,仅在输入层加入提示可能不够,需结合 prefix tuning 等结构。

适用场景与落地建议

  • 文本分类、情感分析、自然语言推断:P-Tuning 的最典型应用,只需极少量标注数据即可逼近全参数微调。
  • 知识密集型问答:通过为每个知识域训练独立 prompt,快速构建领域适配模块。
  • 多任务学习平台:一个模型底座配多个 prompt 文件,以插件形式服务不同业务线。
  • 尽量避免场景:需要大幅改变模型输出风格的生成任务、需要实时在多个提示间切换的极低延迟系统(此时缓存静态向量较合适)。

开始使用 P-Tuning,推荐从 Hugging Face PEFT 库中的 PromptEncoder 接口入手,它封装了完整的 P-Tuning 逻辑,可几行代码与你手中的预训练模型结合。无需昂贵硬件,你就能打开大模型微调的新大门。