司法判决预测:基于案件事实与法条的多任务预测

FreeGuideOnline 最新 2026-06-26

样本示例: 事实:被告人张三于2022年3月12日凌晨,翻窗进入某小区302室,盗走笔记本电脑一台、现金人民币5000元。 法条:[刑法第264条] 罪名:盗窃罪 刑期:24个月


数据集常用的是 CAIL2018 (中国裁判文书网公开数据)或类似开源基准。

---

## 为什么需要多任务学习

单独为罪名和法条训练两个分类器会忽略二者之间的强逻辑关系。例如,法条 `第264条` 唯一对应 `盗窃罪`,而某些罪名可能对应多个法条。多任务联合建模的优点包括:

- **语义共享**:事实编码器的表示同时服务于所有任务,减少重复计算,提高泛化能力。
- **任务约束**:通过任务间的依赖建模(如法条决定罪名),可消除不符合法律逻辑的预测组合。
- **数据不平衡缓解**:低频法条或罪名可通过高频关联任务获得更好的学习信号。

本教程将实现一种经典的“硬参数共享”多任务架构,并引入法条 -> 罪名的显式依赖建模。

---

## 整体架构设计

系统由三大模块构成:

1. **文本编码器**:将事实文本转化为固定长度的语义向量。
2. **任务特有分类头**:三个独立的输出层,分别预测法条、罪名、刑期。
3. **依赖注入层**:法条预测结果会被作为额外特征输入到罪名预测中,模拟判决逻辑。

事实文本 --> [编码器] --> 共享表示 | +--> 法条分类头 --> 法条概率 +--> 依赖融合 --> 罪名分类头 --> 罪名概率 +--> 刑期回归头 --> 刑期数值


---

## 从零实现

我们将基于 PyTorch 构建一个可运行的最小化 demo。假设你已安装 `torch`、`transformers` 和 `scikit-learn`。

### 步骤1:数据预处理

使用一个简化的合成数据集模拟真实格式。实际应用时,你需要解析 CAIL 的 JSON 文件。

```python
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer

class LJP_Dataset(Dataset):
    def __init__(self, facts, articles, charges, penalties, tokenizer, max_len=512):
        self.facts = facts
        self.articles = articles
        self.charges = charges
        self.penalties = penalties
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.facts)

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            self.facts[idx],
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
            return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'article': torch.tensor(self.articles[idx]),   # 法条 id
            'charge': torch.tensor(self.charges[idx]),     # 罪名 id
            'penalty': torch.tensor(self.penalties[idx], dtype=torch.float)  # 刑期值
        }

步骤2:模型定义

采用 Bernal-NLP 作为共享编码器(BERT-base-chinese),上面接三个任务头。同时定义依赖注入:将法条预测的 logits 经过一个线性层后加到罪名分类头的输入中。

from transformers import BertModel
import torch.nn as nn

class MultiTaskLJP(nn.Module):
    def __init__(self, pretrained_model_name, num_articles, num_charges):
        super().__init__()
        self.encoder = BertModel.from_pretrained(pretrained_model_name)
        hidden_size = self.encoder.config.hidden_size

        # 法条分类头
        self.article_clf = nn.Linear(hidden_size, num_articles)
        # 法条到罪名的转换矩阵
        self.article_to_charge = nn.Linear(num_articles, hidden_size)
        # 罪名分类头(额外接收法条信息)
        self.charge_clf = nn.Linear(hidden_size * 2, num_charges)
        # 刑期回归头
        self.penalty_reg = nn.Sequential(
            nn.Linear(hidden_size, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, input_ids, attention_mask):
        # 共享编码
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        cls_emb = outputs.last_hidden_state[:, 0, :]   # [batch, hidden]

        # 法条预测
        article_logits = self.article_clf(cls_emb)     # [batch, num_articles]
        # 法条信息注入
        article_info = self.article_to_charge(article_logits)  # [batch, hidden]
        # 罪名预测(拼接原始表示和法条信息)
        charge_input = torch.cat([cls_emb, article_info], dim=-1)
        charge_logits = self.charge_clf(charge_input)  # [batch, num_charges]

        # 刑期预测(回归,输出正数)
        penalty_pred = self.penalty_reg(cls_emb).squeeze(-1)   # [batch]
        penalty_pred = torch.clamp(penalty_pred, min=0)

        return article_logits, charge_logits, penalty_pred

步骤3:损失函数与训练

三个子任务对应不同损失:法条和罪名使用交叉熵(分类),刑期使用均方误差(回归)。总损失为三者加权和。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MultiTaskLJP('bert-base-chinese', num_articles=100, num_charges=80).to(device)

criterion_cls = nn.CrossEntropyLoss()
criterion_reg = nn.MSELoss()

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

def train_epoch(dataloader, model, optimizer):
    model.train()
    total_loss = 0
    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        art_label = batch['article'].to(device)
        chg_label = batch['charge'].to(device)
        pen_label = batch['penalty'].to(device)

        optimizer.zero_grad()
        art_logits, chg_logits, pen_pred = model(input_ids, attention_mask)

        loss_art = criterion_cls(art_logits, art_label)
        loss_chg = criterion_cls(chg_logits, chg_label)
        loss_pen = criterion_reg(pen_pred, pen_label)

        # 加权超参,可按需要调整
        loss = loss_art + loss_chg + 0.1 * loss_pen
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(dataloader)

步骤4:评估指标

  • 法条/罪名:采用微观 F1 和准确率。
  • 刑期:采用平均绝对误差(MAE)或均方根误差。
from sklearn.metrics import f1_score, accuracy_score, mean_absolute_error

def evaluate(model, dataloader):
    model.eval()
    art_preds, art_labels = [], []
    chg_preds, chg_labels = [], []
    pen_preds, pen_labels = [], []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            art_logits, chg_logits, pen_pred = model(input_ids, attention_mask)

            art_preds.extend(torch.argmax(art_logits, dim=1).cpu().tolist())
            art_labels.extend(batch['article'].tolist())
            chg_preds.extend(torch.argmax(chg_logits, dim=1).cpu().tolist())
            chg_labels.extend(batch['charge'].tolist())
            pen_preds.extend(pen_pred.cpu().tolist())
            pen_labels.extend(batch['penalty'].tolist())

    art_f1 = f1_score(art_labels, art_preds, average='micro')
    art_acc = accuracy_score(art_labels, art_preds)
    chg_f1 = f1_score(chg_labels, chg_preds, average='micro')
    chg_acc = accuracy_score(chg_labels, chg_preds)
    pen_mae = mean_absolute_error(pen_labels, pen_preds)

    return {'art_f1': art_f1, 'art_acc': art_acc,
            'chg_f1': chg_f1, 'chg_acc': chg_acc,
            'pen_mae': pen_mae}

进阶技巧与真实场景优化

1. 拓扑依赖约束

更精细的架构会直接利用法条与罪名之间的层级图。可构建一个拓扑 LSTM 或图卷积网络,强制法条隐状态顺序影响罪名预测,而非简单拼接。

2. 刑期区间化

刑期预测常见的做法是将其建模为分类问题(如0-6月、6-12月等区间),因为刑期分布极不均匀且存在无期、死刑等特殊值。使用有序回归或 Ordinal Regression 能提升效果。

3. 法条推荐作为辅助

部分案件会涉及多条法条,此时需要将其作为多标签分类。可采用二元交叉熵损失,并在罪名预测头中利用法条的多热向量。

4. 预训练法律语言模型

使用 Lawformer、Legal-BERT 等在法律文书上预训练的模型替换通用 BERT,可大幅提升低资源法条和罪名的识别能力。

5. 数据增强与对抗训练

针对低频类别,可利用同义词替换、回译等方式生成更多事实文本。对抗训练(FGM)可增强模型的鲁棒性。


完整代码总结

上述代码片段整合后即可运行一个基础的多任务判决预测模型。完整项目结构建议如下:

ljp_project/
│
├── data/
│   └── cail_small.json
├── dataset.py
├── model.py
├── train.py
├── eval.py
└── utils.py