司法判决预测:基于案件事实与法条的多任务预测
FreeGuideOnline
最新
2026-06-26
样本示例: 事实:被告人张三于2022年3月12日凌晨,翻窗进入某小区302室,盗走笔记本电脑一台、现金人民币5000元。 法条:[刑法第264条] 罪名:盗窃罪 刑期:24个月
数据集常用的是 CAIL2018 (中国裁判文书网公开数据)或类似开源基准。
---
## 为什么需要多任务学习
单独为罪名和法条训练两个分类器会忽略二者之间的强逻辑关系。例如,法条 `第264条` 唯一对应 `盗窃罪`,而某些罪名可能对应多个法条。多任务联合建模的优点包括:
- **语义共享**:事实编码器的表示同时服务于所有任务,减少重复计算,提高泛化能力。
- **任务约束**:通过任务间的依赖建模(如法条决定罪名),可消除不符合法律逻辑的预测组合。
- **数据不平衡缓解**:低频法条或罪名可通过高频关联任务获得更好的学习信号。
本教程将实现一种经典的“硬参数共享”多任务架构,并引入法条 -> 罪名的显式依赖建模。
---
## 整体架构设计
系统由三大模块构成:
1. **文本编码器**:将事实文本转化为固定长度的语义向量。
2. **任务特有分类头**:三个独立的输出层,分别预测法条、罪名、刑期。
3. **依赖注入层**:法条预测结果会被作为额外特征输入到罪名预测中,模拟判决逻辑。
事实文本 --> [编码器] --> 共享表示 | +--> 法条分类头 --> 法条概率 +--> 依赖融合 --> 罪名分类头 --> 罪名概率 +--> 刑期回归头 --> 刑期数值
---
## 从零实现
我们将基于 PyTorch 构建一个可运行的最小化 demo。假设你已安装 `torch`、`transformers` 和 `scikit-learn`。
### 步骤1:数据预处理
使用一个简化的合成数据集模拟真实格式。实际应用时,你需要解析 CAIL 的 JSON 文件。
```python
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
class LJP_Dataset(Dataset):
def __init__(self, facts, articles, charges, penalties, tokenizer, max_len=512):
self.facts = facts
self.articles = articles
self.charges = charges
self.penalties = penalties
self.tokenizer = tokenizer
self.max_len = max_len
def __len__(self):
return len(self.facts)
def __getitem__(self, idx):
encoding = self.tokenizer(
self.facts[idx],
truncation=True,
padding='max_length',
max_length=self.max_len,
return_tensors='pt'
)
return {
'input_ids': encoding['input_ids'].squeeze(0),
'attention_mask': encoding['attention_mask'].squeeze(0),
'article': torch.tensor(self.articles[idx]), # 法条 id
'charge': torch.tensor(self.charges[idx]), # 罪名 id
'penalty': torch.tensor(self.penalties[idx], dtype=torch.float) # 刑期值
}
步骤2:模型定义
采用 Bernal-NLP 作为共享编码器(BERT-base-chinese),上面接三个任务头。同时定义依赖注入:将法条预测的 logits 经过一个线性层后加到罪名分类头的输入中。
from transformers import BertModel
import torch.nn as nn
class MultiTaskLJP(nn.Module):
def __init__(self, pretrained_model_name, num_articles, num_charges):
super().__init__()
self.encoder = BertModel.from_pretrained(pretrained_model_name)
hidden_size = self.encoder.config.hidden_size
# 法条分类头
self.article_clf = nn.Linear(hidden_size, num_articles)
# 法条到罪名的转换矩阵
self.article_to_charge = nn.Linear(num_articles, hidden_size)
# 罪名分类头(额外接收法条信息)
self.charge_clf = nn.Linear(hidden_size * 2, num_charges)
# 刑期回归头
self.penalty_reg = nn.Sequential(
nn.Linear(hidden_size, 64),
nn.ReLU(),
nn.Linear(64, 1)
)
def forward(self, input_ids, attention_mask):
# 共享编码
outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
cls_emb = outputs.last_hidden_state[:, 0, :] # [batch, hidden]
# 法条预测
article_logits = self.article_clf(cls_emb) # [batch, num_articles]
# 法条信息注入
article_info = self.article_to_charge(article_logits) # [batch, hidden]
# 罪名预测(拼接原始表示和法条信息)
charge_input = torch.cat([cls_emb, article_info], dim=-1)
charge_logits = self.charge_clf(charge_input) # [batch, num_charges]
# 刑期预测(回归,输出正数)
penalty_pred = self.penalty_reg(cls_emb).squeeze(-1) # [batch]
penalty_pred = torch.clamp(penalty_pred, min=0)
return article_logits, charge_logits, penalty_pred
步骤3:损失函数与训练
三个子任务对应不同损失:法条和罪名使用交叉熵(分类),刑期使用均方误差(回归)。总损失为三者加权和。
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MultiTaskLJP('bert-base-chinese', num_articles=100, num_charges=80).to(device)
criterion_cls = nn.CrossEntropyLoss()
criterion_reg = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
def train_epoch(dataloader, model, optimizer):
model.train()
total_loss = 0
for batch in dataloader:
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
art_label = batch['article'].to(device)
chg_label = batch['charge'].to(device)
pen_label = batch['penalty'].to(device)
optimizer.zero_grad()
art_logits, chg_logits, pen_pred = model(input_ids, attention_mask)
loss_art = criterion_cls(art_logits, art_label)
loss_chg = criterion_cls(chg_logits, chg_label)
loss_pen = criterion_reg(pen_pred, pen_label)
# 加权超参,可按需要调整
loss = loss_art + loss_chg + 0.1 * loss_pen
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / len(dataloader)
步骤4:评估指标
- 法条/罪名:采用微观 F1 和准确率。
- 刑期:采用平均绝对误差(MAE)或均方根误差。
from sklearn.metrics import f1_score, accuracy_score, mean_absolute_error
def evaluate(model, dataloader):
model.eval()
art_preds, art_labels = [], []
chg_preds, chg_labels = [], []
pen_preds, pen_labels = [], []
with torch.no_grad():
for batch in dataloader:
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
art_logits, chg_logits, pen_pred = model(input_ids, attention_mask)
art_preds.extend(torch.argmax(art_logits, dim=1).cpu().tolist())
art_labels.extend(batch['article'].tolist())
chg_preds.extend(torch.argmax(chg_logits, dim=1).cpu().tolist())
chg_labels.extend(batch['charge'].tolist())
pen_preds.extend(pen_pred.cpu().tolist())
pen_labels.extend(batch['penalty'].tolist())
art_f1 = f1_score(art_labels, art_preds, average='micro')
art_acc = accuracy_score(art_labels, art_preds)
chg_f1 = f1_score(chg_labels, chg_preds, average='micro')
chg_acc = accuracy_score(chg_labels, chg_preds)
pen_mae = mean_absolute_error(pen_labels, pen_preds)
return {'art_f1': art_f1, 'art_acc': art_acc,
'chg_f1': chg_f1, 'chg_acc': chg_acc,
'pen_mae': pen_mae}
进阶技巧与真实场景优化
1. 拓扑依赖约束
更精细的架构会直接利用法条与罪名之间的层级图。可构建一个拓扑 LSTM 或图卷积网络,强制法条隐状态顺序影响罪名预测,而非简单拼接。
2. 刑期区间化
刑期预测常见的做法是将其建模为分类问题(如0-6月、6-12月等区间),因为刑期分布极不均匀且存在无期、死刑等特殊值。使用有序回归或 Ordinal Regression 能提升效果。
3. 法条推荐作为辅助
部分案件会涉及多条法条,此时需要将其作为多标签分类。可采用二元交叉熵损失,并在罪名预测头中利用法条的多热向量。
4. 预训练法律语言模型
使用 Lawformer、Legal-BERT 等在法律文书上预训练的模型替换通用 BERT,可大幅提升低资源法条和罪名的识别能力。
5. 数据增强与对抗训练
针对低频类别,可利用同义词替换、回译等方式生成更多事实文本。对抗训练(FGM)可增强模型的鲁棒性。
完整代码总结
上述代码片段整合后即可运行一个基础的多任务判决预测模型。完整项目结构建议如下:
ljp_project/
│
├── data/
│ └── cail_small.json
├── dataset.py
├── model.py
├── train.py
├── eval.py
└── utils.py