
《动手学深度学习》(PyTorch版)代码注释 - 53 【Find_synonyms&analogies】
发布日期:2021-05-19 18:03:21
浏览次数:25
分类:精选文章
本文共 2422 字,大约阅读时间需要 8 分钟。
项目说明
本文的代码以开源项目为基础,结合个人学习理解的注释优化而成,方便读者快速理解各个功能的实现原理和实用场景。
开发环境
使用环境:Python 3.8
平台:Windows 10
开发工具:PyCharm
主要功能说明
本节主要实现二者关系表示的两种任务:近义词和类比词搜索。该功能基于文本向量模型 GloVe,适合用于自然语言处理中的词语关联分析。篇幅较简,代码注释相对较少,但仍保留了核心实现逻辑。
GloVe 数据集推荐使用官方提供的预训练包,建议下载后直接使用即可。
核心代码示例
以下是实现近义词和类比词搜索的关键代码段
# 从 matplotlib 导入绘图库import matplotlib.pyplot as plt# 导入 PyTorch 和文本词库模块import torchfrom torchtext.vocab import vocab# 打印预训练词向量的支持 aliasesprint(vocab.pretrained_aliases.keys())print([key for key in vocab.pretrained_aliases.keys() if "glove" in key])# 定义 GloVe 模型参数cache_dir = "D:/Program/Pytorch/Datasets/glove"glove = vocab.GloVe(name='6B', dim=50, cache=cache_dir)# 打印词向量的基本信息print("一共包含 {} 个词。".format(len(glove.stoi)))print("词向量示例:{} 对应({})".format(glove.stoi['beautiful'], globe.itos[3366]))# 计算余弦相似度的近邻居函数def knn(W, x, k): # 添加数值稳定常数 cos = torch.matmul(W, x.view((-1,))) / ( (torch.sum(W * W, dim=1) + 1e-9).sqrt() * torch.sum(x * x).sqrt() ) _, topk = torch.topk(cos, k=k) return topk, [cos[i].item() for i in topk]# 搜索与输入词最相似的 k 个词def get_similar_tokens(query_token, k, embed): topk, cos = knn(embed.vectors, embed.vectors[embed.stoi[query_token]], k+1) # 打印最相似词及其余弦值 for i, c in zip(topk[1:], cos[1:]): print(f"cosine sim {c:.3f}: {embed.itos[i]}")# 使用示例get_similar_tokens('chip', 3, glove) get_similar_tokens('baby', 3, glove) get_similar_tokens('beautiful', 3, glove) # 搜索类比关系词def get_analogy(token_a, token_b, token_c, embed): vecs = [embed.vectors[embed.stoi[t]] for t in [token_a, token_b, token_c]] x = vecs[1] - vecs[0] + vecs[2] topk, cos = knn(embed.vectors, x, 3) # 打印找到的类比词及其余弦值 for i, c in zip(topk[:], cos[:]): print(f"origin world = {token_c}, cosine sim {c:.3f}: {embed.itos[i]}") return embed.itos[topk[0]]# 使用示例print(get_analogy('man', 'woman', 'son', glove)) # 输出:daughter print(get_analogy('beijing', 'china', 'tokyo', glove)) # 输出:japan print(get_analogy('bad', 'worst', 'big', glove)) # 输出:biggest print(get_analogy('do', 'did', 'go', glove)) # 输出:went
代码注释解读
1. **GloVe 模型加载**
- 使用已预训练的 50 维词向量模型 - 字典路径来自指定缓存目录2. **余弦相似度计算函数**
- `knk(W, x, k)` 是核心计算函数 - W 是词矩阵,x 是查询词向量 - 加入 `1e-9` 来防止数值错误 - 使用 PyTorch 的 `topk` 计算最接近的 k 个词3. **搜索最相似词**
- `get_similar_tokens` 调用 `knk` 求得最接近的 k 个词 - 逐个输出词语及其余弦值 - 示例搜索 'chip'、'baby' 和 'beautiful' 的近义词4. **类比关系搜索**
- `get_analogy` 计算词语关系 - 输入三词,计算并输出最适合的类比词 - 示例验证了多个实际应用场景发表评论
最新留言
不错!
[***.144.177.141]2025年05月06日 23时19分35秒
关于作者

喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!
推荐文章
Java基础:Character 类概念、构造函数、实例方法、类方法
2023-01-29
Kubernetes 资源调度详解
2023-01-29
Java基础:Java 的工作原理和 Java 开发环境
2023-01-29
Java基础:StringBuffer类概念、构造函数、常用方法
2023-01-29
Kubernetes 部署 kubeflow1.7.0
2023-01-29
Java基础:变量(声明、赋值、引用)、基本数据类型、作用域
2023-01-29
Kubernetes 部署SonarQube
2023-01-29
Java基础:如何编写并执行入门级别程序 Hello World
2023-01-29
Java基础:循环语句for、while和do-while
2023-01-29
kubernetes 部署SonarQube 7.1 关联LDAP
2023-01-29
Java基础:按位运算符
2023-01-29
Kubernetes 配置管理实战
2023-01-29
Java基础:数字类概念、常用方法、常量
2023-01-29
Kubernetes 针对资源紧缺处理方式的配置
2023-01-29
Java基础:数组创建、初始化、引用、分类
2023-01-29
Java基础:数组的长度、数组的复制
2023-01-29
Kubernetes 问题总结
2023-01-29
Java基础:条件运算符
2023-01-29
Kubernetes 集成Traefik(一)—— 转发鉴权
2023-01-29
Java基础:比较运算符
2023-01-29