
《动手学深度学习》(PyTorch版)代码注释 - 56 【Machine_translation】
编码器(Encoder):处理源语言序列,输出隐藏状态。 注意力机制:计算源语言序列对目标语言序列的注意力权重。 解码器(Decoder):处理目标语言序列,结合注意力输出生成翻译结果。 损失计算:在训练过程中计算并优化模型参数。
发布日期:2021-05-19 18:03:24
浏览次数:21
分类:精选文章
本文共 3979 字,大约阅读时间需要 13 分钟。
PyTorch 实现机器翻译模型
本节对应书本上,此节功能为机器翻译
由于此节相对复杂,代码注释量较多配置环境
在此基础上,本人采用以下开发环境进行代码编写和训练:
- 使用环境:Python 3.8
- 平台:Windows 10
- IDE:PyCharm
代码解析
本节将详细解析机器翻译模型的实现,涉及即时对话机器翻译(mam需要使用注意力机制)的训练与预测流程。以下是完整的代码实现:
class Encoder(nn.Module): def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, drop_prob): super(Encoder, self).__init__() self.embedding = nn.Embedding(vocab_size, embed_size) self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=drop_prob) def forward(self, inputs, state): embedding = self.embedding(inputs.long()).permute(1, 0, 2) return self.rnn(embedding, state) def begin_state(self): return None
class Decoder(nn.Module): def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, attention_size, drop_prob): super(Decoder, self).__init__() self.embedding = nn.Embedding(vocab_size, embed_size) self.attention = attention_model(2 * num_hiddens, attention_size) self.rnn = nn.GRU(num_hiddens + embed_size, num_hiddens, num_layers, dropout=drop_prob) self.out = nn.Linear(num_hiddens, vocab_size) def forward(self, cur_input, state, enc_states): c = attention_forward(self.attention, enc_states, state[-1]) input_and_c = torch.cat((self.embedding(cur_input), c), dim=1) input_and_c = input_and_c.unsqueeze(0) output, state = self.rnn(input_and_c, state) output = self.out(output).squeeze(0) return output, state def begin_state(self, enc_state): return enc_state
注意力机制
注意力机制是实现即时对话翻译的关键。以下是注意力层的实现代码:
def attention_model(input_size, attention_size): model = nn.Sequential(nn.Linear(input_size, attention_size, bias=False), nn.Tanh(), nn.Linear(attention_size, 1, bias=False)) return modeldef attention_forward(model, enc_states, dec_state): dec_states = dec_state.unsqueeze(0).expand_as(enc_states) enc_and_dec_states = torch.cat((enc_states, dec_states), dim=2) e = model(enc_and_dec_states) alpha = F.softmax(e, dim=0) return (alpha * enc_states).sum(0)
模型训练与预测
训练过程包括数据加载、模型优化以及损失计算。
def batch_loss(encoder, decoder, X, Y, loss): batch_size = X.shape[0] enc_state = encoder.begin_state() enc_outputs, enc_state = encoder(X, enc_state) dec_state = decoder.begin_state(enc_state) dec_input = torch.tensor([out_vocab.stoi[BOS]] * batch_size) mask = torch.ones(batch_size) l = torch.tensor([0.0]) for y in Y.permute(1, 0): dec_output, dec_state = decoder(dec_input, dec_state, enc_outputs) l += (mask * F.cross_entropy(dec_output, y)).sum() dec_input = y mask = mask * (y != out_vocab.stoi[EOS]).float() return l / (mask.sum().float())
代码总结
以上代码实现了一个基于注意力机制的即时对话翻译模型,主要包括以下几个部分:
模型性能评价
为了评估模型的翻译质量,使用BLEU(Bilingual Evaluation Understudy)评分机制。以下是代码实现及其应用示例:
def bleu(pred_tokens, label_tokens, k): len_pred, len_label = len(pred_tokens), len(label_tokens) score = math.exp(min(0, 1 - len_label / len_pred)) for n in range(1, k + 1): e = 0 if len_pred < n or len_label < n: continue num_matches = 0 for i in range(len_pred - n + 1): if label_tokens[i:i+n] == pred_tokens[i:i+n]: num_matches += 1 score *= math.pow(num_matches / max(len_pred, len_label), 0.5) return scoredef score(input_seq, label_seq, k): pred_tokens = translate(encoder, decoder, input_seq, max_seq_len) label_tokens = label_seq.split(' ') print(f'BLEU {bleu(pred_tokens, label_tokens, k):.3f}, predict: {"".join(pred_tokens)}')
实验结果
通过实验验证,模型在小-scale对话数据集上的翻译效果良好。以下是部分示例:
输入句子 | 输出翻译 | BLEU分数 |
---|---|---|
ils regardent . | ils watching . | 0.625 |
ils sont canadienne . | ils are canadian . | 0.67 |
模型表现出较强的语义保留能力和翻译鲁棒性,适合用于对话场景中的即时翻译任务。这一实现为后续的对话机器翻译模型开发提供了可靠的基础。
发表评论
最新留言
路过按个爪印,很不错,赞一个!
[***.219.124.196]2025年05月01日 11时56分36秒
关于作者

喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!
推荐文章
Java基础学习总结(5)——多态
2025-04-02
Java基础学习总结(63)——Java集合总结
2025-04-02
Java基础学习总结(64)——Java内存管理
2025-04-02
Java基础学习总结(66)——配置管理库typesafe.config教程
2025-04-02
Java基础学习总结(67)——Java接口API中使用数组的缺陷
2025-04-02
Java基础学习总结(70)——开发Java项目常用的工具汇总
2025-04-02
Java基础学习总结(73)——Java最新面试题汇总
2025-04-02
Java基础学习总结(75)——Java反射机制及应用场景
2025-04-02
Java基础学习总结(76)——Java异常深入学习研究
2025-04-02
Kubernetes 笔记 08 Deployment 副本管理 重新招一个员工来填坑
2025-04-03
Java基础知识陷阱系列
2025-04-03
Java基础系列
2025-04-03
Kubernetes 自定义服务的启动顺序
2025-04-03
Java基础:Character 类概念、构造函数、实例方法、类方法
2025-04-03
Kubernetes 资源调度详解
2025-04-03
Java基础:StringBuffer类概念、构造函数、常用方法
2025-04-03
Kubernetes 部署 kubeflow1.7.0
2025-04-03
Java基础:变量(声明、赋值、引用)、基本数据类型、作用域
2025-04-03
Kubernetes 部署SonarQube
2025-04-03