SASRec:基于自注意力的序列推荐模型

FreeGuideOnline 最新 2026-06-23

输入 = E + P[:n, :]


即物品嵌入与对应位置嵌入相加(逐元素求和)。值得注意的是,SASRec并未采用Transformer中的正弦位置编码,而是选择**可学习**的位置嵌入,实验表明在推荐任务中适应性更好。

## 4. 自注意力块:多头的受限注意力

### 4.1 缩放点积注意力
给定查询(Q)、键(K)、值(V),注意力计算公式为:

Attention(Q, K, V) = softmax( (QK^T)/√d_k + M ) V


其中 `M` 为**因果关系掩码**,是一个下三角矩阵(允许看到本身),将未来位置对应的注意力分数设为 -∞,softmax后权重接近0,保证训练时预测第`t`个物品只使用前`t-1`个物品的信息。

### 4.2 多头注意力
SASRec使用多头机制,将 `d` 维的嵌入拆分到 `h` 个不同的子空间,独立计算注意力后拼接,再通过线性变换映射回原维度。这使模型能关注不同子空间中的复杂行为模式。

```python
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O
head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)

4.3 前馈网络与残差连接

每个自注意力子层后跟一个 Position-wise 的前馈网络(两层全连接,ReLU激活),并采用残差连接(Add)和层归一化(Norm)稳定训练:

F = Attention(LayerNorm(S)) + S
Output = FFN(LayerNorm(F)) + F

注意:SASRec 采用先层归一化再进入子层的预归一化(Pre-LN)结构,与原Transformer后归一化有细微差别,但在实践中更利于训练。

5. 堆叠多个自注意力块

一个自注意力块由一层多头自注意力加一层前馈网络组成。SASRec 会堆叠 B 个这样的块,逐步学习更高阶的物品转换模式。经过 B 层之后,得到最终每个时间步的隐状态表示 H ∈ R^{n×d}。通常取最后一个位置的输出 h_n 作为用户当前序列的聚合表示(也可取所有位置做平均,但原文使用最后一个)。

6. 预测层与训练目标

将最后一个位置的隐向量 h_n 与物品嵌入矩阵 M 计算内积,再经过 softmax 得到下一个物品的概率分布:

P(v_{n+1}=i) = softmax( M_i · h_n^T )

注意这里使用了共享嵌入权重:预测时的权重矩阵与输入时的物品嵌入矩阵相同。这不仅可以减少参数量,还能对物品表示形成有效的正则化,类似 word2vec 的思想。

损失函数

SASRec 采用二进制交叉熵(Binary Cross Entropy)负对数似然(Negative Log-Likelihood)。通常训练时对每个正样本(真实的下一个物品)随机采样一个或几个负样本,并将任务视为二元分类(区分正负样本)以提高效率。也可以使用全物品 softmax,但物品数量巨大时计算代价高,通常配合采样 softmax。

7. SASRec 与其他模型的对比

模型类型 代表模型 核心机制 优势 劣势
RNN GRU4Rec 循环网络 天然处理序列顺序 无法并行,捕捉长期依赖困难
CNN Caser 水平/垂直卷积 捕获局部跳跃模式 视野受卷积核大小限制
注意力 SASRec 自注意力 并行、长程依赖、可解释 对序列长度敏感,需要掩码

实验证明,SASRec在稀疏数据集和稠密数据集上均能超越或持平当时最强基线,尤其对长序列更有优势。

8. 从零实现 SASRec(简版 PyTorch 代码框架)

下面展示一个简化的 SASRec 实现,突出核心模块。省略了数据预处理和训练循环。

import torch
import torch.nn as nn
import torch.nn.functional as F

class SASRec(nn.Module):
    def __init__(self, item_num, max_len, embed_dim, num_heads, num_blocks, dropout=0.1):
        super().__init__()
        self.item_emb = nn.Embedding(item_num + 1, embed_dim, padding_idx=0)  # 0为填充
        self.pos_emb = nn.Embedding(max_len, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.encoder_blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, dropout) for _ in range(num_blocks)
        ])
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, seq, pos_ids):
        # seq: (batch, max_len)
        seq_emb = self.item_emb(seq)  # (B, L, d)
        pos_emb = self.pos_emb(pos_ids)  # (B, L, d)
        x = self.dropout(seq_emb + pos_emb)

        mask = torch.tril(torch.ones(seq.size(1), seq.size(1)), diagonal=0).to(seq.device)
        mask = mask == 0  # True for positions to mask

        for block in self.encoder_blocks:
            x = block(x, mask)
        x = self.layer_norm(x)

        # 取最后一个位置的输出作为序列表示
        last_hidden = x[:, -1, :]  # (B, d)
        return last_hidden

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.ReLU(),
            nn.Linear(embed_dim * 4, embed_dim)
        )
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        # Pre-LN + Multihead Attention (self-attention)
        attn_out, _ = self.attention(self.ln1(x), self.ln1(x), self.ln1(x), attn_mask=mask)
        x = x + self.dropout(attn_out)
        # Pre-LN + Feed Forward
        ff_out = self.feed_forward(self.ln2(x))
        x = x + self.dropout(ff_out)
        return x

预测时,可利用last_hidden与物品嵌入矩阵做内积(如果共享权重,则直接torch.matmul(last_hidden, self.item_emb.weight.T))并计算损失。采样softmax可用torch.nn.CrossEntropyLoss配合负采样实现。

9. 训练与评估实践建议

9.1 数据预处理

  • 按用户分组构建交互序列,按时间排序。
  • 序列截断:固定最大长度 N,长序列截断最近的交互(保留尾部),短序列左侧补零。
  • 训练时采用滑动窗口生成(input_seq, target_item)对,目标为input_seq的下一个物品。

9.2 负采样策略

由于物品数量庞大,全量softmax计算代价高。推荐使用批次内负采样随机负采样。每个正样本配对1~5个负样本,采用BCE损失:

loss = - (log(σ(score_pos)) + Σ log(1 - σ(score_neg)))