
center loss代码PyTorch
发布日期:2021-05-06 11:08:18
浏览次数:63
分类:精选文章
本文共 779 字,大约阅读时间需要 2 分钟。
CenterLoss 实现解析
代码解析
关键代码行
distmat.addmm_(1, -2, x, centers.t())
这行代码相当于:
\text{distmat} = \text{distmat} - 2 \times x \times \text{centers}^T
其中,distmat
是先前计算的两组特征向量的平方和矩阵:
\text{distmat} = (\|x\|_2^2) \otimes 1 + (\|c\|_2^2) \otimes 1
这一步通过平方差公式展开计算特征向量之间的距离矩阵。
mask 计算
mask = labels.eq(classes.expand(batch_size, num_classes))
torch.eq()
运算符用于比较 labels
和 classes
,生成一个布尔矩阵 mask
,其中 True
表示特征向量属于相同类别。
损失计算
dist = []for i in range(batch_size): value = distmat[i][mask[i]] value = value.clamp(min=1e-12, max=1e+12) # 数值稳定性处理 dist.append(value)dist = torch.cat(dist)loss = dist.mean()
将所有类别内的距离取平均值作为损失函数。
示例说明
假设:
batch_size = 5
num_classes = 6
x = torch.rand(5, 10)
(随机特征向量)targets = torch.Tensor([0, 1, 2, 3, 2]).long()
(标签)
CenterLoss
会计算每个样本到类中心的欧氏距离,并通过平均损失函数优化模型参数。
发表评论
最新留言
表示我来过!
[***.240.166.169]2025年04月14日 11时26分43秒
关于作者

喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!
推荐文章
【Maven】POM基本概念
2019-03-06
【Java思考】Java 中的实参与形参之间的传递到底是值传递还是引用传递呢?
2019-03-06
【设计模式】单例模式
2019-03-06
【Linux】2.3 Linux目录结构
2019-03-06
远程触发Jenkins的Pipeline任务的并发问题处理
2019-03-06
Web应用程序并发问题处理的一点小经验
2019-03-06
entity framework core在独立类库下执行迁移操作
2019-03-06
Asp.Net Core 2.1+的视图缓存(响应缓存)
2019-03-06
RE套路 - 关于pyinstaller打包文件的复原
2019-03-06
【wp】HWS计划2021硬件安全冬令营线上选拔赛
2019-03-06
Ef+T4模板实现代码快速生成器
2019-03-06
c++ static笔记
2019-03-06
C++中头文件相互包含与前置声明
2019-03-06
JQuery选择器
2019-03-06
多线程之volatile关键字
2019-03-06
2.2.2原码补码移码的作用
2019-03-06
Java面试题:Servlet是线程安全的吗?
2019-03-06
Java集合总结系列2:Collection接口
2019-03-06
Linux学习总结(九)—— CentOS常用软件安装:中文输入法、Chrome
2019-03-06
MySQL用户管理:添加用户、授权、删除用户
2019-03-06