AutoInt:用自注意力学习特征间的高阶交互

FreeGuideOnline 最新 2026-06-24

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
```

模拟数据:batch_size=4, 3个特征域,每个域的特征数分别为100,50,20

batch_size = 4
field_dims = [100, 50, 20]  # vocabulary size of each feature field
embed_dim = 16  # embedding dimension
num_fields = len(field_dims)

# Random input: column i holds integer ids drawn from [0, field_dims[i]), so
# every id is a valid row of that field's own embedding table. Sampling all
# columns from [0, max(field_dims)) would produce out-of-range ids for the
# smaller fields.
x = torch.stack(
    [torch.randint(0, dim, (batch_size,)) for dim in field_dims], dim=1
)
print(x.shape)  # [4, 3]


### 2. 嵌入层与位置编码

```python
class FeaturesEmbedding(nn.Module):
    """Embed each categorical field with its own lookup table.

    Args:
        field_dims: Vocabulary size of every field; one ``nn.Embedding`` is
            created per entry.
        embed_dim: Size of each embedding vector.
    """

    def __init__(self, field_dims, embed_dim):
        super().__init__()
        self.embedding = nn.ModuleList([
            nn.Embedding(dim, embed_dim) for dim in field_dims
        ])

    def forward(self, x):
        """Look up every field and stack the results.

        Args:
            x: Integer tensor of shape ``[batch_size, num_fields]``; column
                ``i`` must hold ids valid for table ``i``.

        Returns:
            Tensor of shape ``[batch_size, num_fields, embed_dim]``.
        """
        # Iterate over the layer's own tables instead of a module-level
        # ``field_dims`` global, so the layer works for whatever field sizes
        # it was constructed with.
        embs = [table(x[:, i]) for i, table in enumerate(self.embedding)]
        # Each lookup is [batch_size, embed_dim]; stack along a new field axis.
        return torch.stack(embs, dim=1)

位置编码采用可学习的参数:

class PositionalEncoding(nn.Module):
    """Learnable per-field encoding that is added to the field embeddings.

    Args:
        num_fields: Number of feature fields (rows of the encoding table).
        embed_dim: Embedding size (columns of the encoding table).
    """

    def __init__(self, num_fields, embed_dim):
        super().__init__()
        initial = torch.randn(num_fields, embed_dim)
        self.pos_embedding = nn.Parameter(initial)

    def forward(self, x):
        """Return ``x`` plus the encoding, broadcast over the batch axis."""
        # x is [batch_size, num_fields, embed_dim]; give the table a leading
        # axis of size 1 so it lines up with the batch dimension.
        batched = self.pos_embedding[None, ...]
        return x + batched

### 3. 多头自注意力交互层

省略FFN,直接使用多头注意力+残差+LayerNorm。

class MultiHeadAttentionInteraction(nn.Module):
    """Single interaction layer: multi-head self-attention over the fields,
    then a residual connection and LayerNorm (no feed-forward sub-layer).

    Args:
        d_model: Width of each field representation; must be divisible by
            ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model, self.n_heads = d_model, n_heads
        self.d_k = d_model // n_heads  # per-head width

        # Query / key / value projections and the output projection.
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t):
        """Reshape ``[B, n, d_model]`` into ``[B, h, n, d_k]``."""
        bsz, n_fields, _ = t.size()
        return t.view(bsz, n_fields, self.n_heads, self.d_k).transpose(1, 2)

    def forward(self, x):
        """Apply self-attention across fields.

        Args:
            x: Tensor of shape ``[batch_size, num_fields, d_model]``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        queries = self._split_heads(self.W_Q(x))
        keys = self._split_heads(self.W_K(x))
        values = self._split_heads(self.W_V(x))

        # Scaled dot-product attention; scores are [B, h, n, n].
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.d_k ** 0.5)
        weights = self.dropout(F.softmax(scores, dim=-1))

        context = torch.matmul(weights, values)  # [B, h, n, d_k]
        # Concatenate the heads back into a single d_model-wide vector.
        merged = context.transpose(1, 2).contiguous().view(
            x.size(0), x.size(1), self.d_model
        )
        projected = self.W_O(merged)

        return self.norm(x + projected)

可以堆叠多个交互层:

class InteractionStack(nn.Module):
    """Sequential stack of ``MultiHeadAttentionInteraction`` layers.

    Args:
        num_layers: How many interaction layers to chain.
        d_model: Width of each field representation.
        n_heads: Number of attention heads per layer.
        dropout: Dropout probability forwarded to every layer.
    """

    def __init__(self, num_layers, d_model, n_heads, dropout=0.1):
        super().__init__()
        stack = nn.ModuleList()
        for _ in range(num_layers):
            stack.append(MultiHeadAttentionInteraction(d_model, n_heads, dropout))
        self.layers = stack

    def forward(self, x):
        """Feed ``x`` through every layer in order and return the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden

### 4. 输出层(含一阶线性部分)

一阶逻辑回归部分直接对原始稀疏特征加权,与自注意力交互结果相加。

class AutoInt(nn.Module):
    """CTR model combining a first-order linear term with stacked
    self-attention interactions over the field embeddings.

    Args:
        field_dims: Vocabulary size of every feature field.
        embed_dim: Embedding width; also the attention model width.
        num_heads: Number of attention heads per interaction layer.
        num_layers: Number of stacked interaction layers.
        dropout: Dropout probability used inside the interaction layers.
    """

    def __init__(self, field_dims, embed_dim, num_heads, num_layers, dropout=0.1):
        super().__init__()
        self.embedding = FeaturesEmbedding(field_dims, embed_dim)
        self.pos_encoding = PositionalEncoding(len(field_dims), embed_dim)
        self.interaction = InteractionStack(num_layers, embed_dim, num_heads, dropout)
        # First-order part: one scalar weight per (field, value) pair, stored
        # in a single table of sum(field_dims) rows.
        self.linear = nn.Embedding(sum(field_dims), 1)
        # Start row of each field inside ``self.linear``. Without these
        # offsets, equal ids from different fields would share one weight.
        # Non-persistent so the state_dict keys stay unchanged.
        self.register_buffer(
            "offsets", self._field_offsets(field_dims), persistent=False
        )
        self.bias = nn.Parameter(torch.zeros(1))
        # Final projection of the flattened interaction output to one logit.
        self.fc = nn.Linear(embed_dim * len(field_dims), 1)

        self._init_weights()

    @staticmethod
    def _field_offsets(field_dims):
        """Return the first row index of each field in the shared linear table."""
        starts = torch.tensor([0, *field_dims[:-1]], dtype=torch.long)
        return starts.cumsum(dim=0)

    def _init_weights(self):
        """Xavier-initialise linear layers; small-normal-initialise embeddings."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, 0, 0.01)

    def forward(self, x):
        """Compute the predicted probability for each sample.

        Args:
            x: Integer tensor of shape ``[batch_size, num_fields]`` holding
                per-field ids (column ``i`` in ``[0, field_dims[i])``).

        Returns:
            Tensor of shape ``[batch_size]`` with sigmoid outputs.
        """
        # Linear part: shift per-field ids into the shared table, then sum the
        # per-field weights -> [B, 1].
        linear_part = self.linear(x + self.offsets).sum(dim=1) + self.bias

        # Self-attention interaction part.
        emb = self.embedding(x)            # [B, n, d]
        emb = self.pos_encoding(emb)
        inter_out = self.interaction(emb)  # [B, n, d]
        # Flatten all field representations; flatten() does not require the
        # tensor to be contiguous, unlike view().
        flat = inter_out.flatten(start_dim=1)  # [B, n*d]
        interaction_part = self.fc(flat)       # [B, 1]

        # Fuse both parts into a single logit per sample.
        logit = linear_part + interaction_part
        return torch.sigmoid(logit.squeeze(1))

测试模型:

# Smoke test: build the model on the sample field sizes and run one forward
# pass on the random batch. The model is left in its default training mode.
model = AutoInt(
    field_dims=field_dims,
    embed_dim=16,
    num_heads=4,
    num_layers=2,
)
print(model)
y = model(x)
print(y.shape)  # [4]