
mxnet METRIC自定义评估验证函数
发布日期:2021-05-07 16:55:25
浏览次数:23
分类:技术文章
本文共 3520 字,大约阅读时间需要 11 分钟。
insightface自定义loss:
params = [1.e-10] sel = mx.symbol.argmax(data = fc7, axis=1) sel = (sel==gt_label) norm = embedding*embedding norm = mx.symbol.sum(norm, axis=1) norm = norm+params[0] feature_incay = sel/norm feature_incay = mx.symbol.mean(feature_incay) * args.incay extra_loss = mx.symbol.MakeLoss(feature_incay)
自定义损失函数
# -*- coding=utf-8 -*-import mxnet as mximport numpy as npimport logginglogging.basicConfig(level=logging.INFO)x = mx.sym.Variable('data')y = mx.sym.FullyConnected(data=x, num_hidden=1)label = mx.sym.Variable('label')cross_entropy = label * log(out) + (1 - label) * log(1 - out)loss = MakeLoss(cross_entropy)pred_loss = mx.sym.Group([mx.sym.BlockGrad(y), loss])ex = pred_loss.simple_bind(mx.cpu(), data=(32, 2))# testtest_data = mx.nd.array(np.random.random(size=(32, 2)))test_label = mx.nd.array(np.random.random(size=(32, 1)))ex.forward(is_train=True, data=test_data, label=test_label)ex.backward()print ex.arg_dictfc_w = ex.arg_dict['fullyconnected0_weight'].asnumpy()fc_w_grad = ex.grad_arrays[1].asnumpy()fc_bias = ex.arg_dict['fullyconnected0_bias'].asnumpy()fc_bias_grad = ex.grad_arrays[2].asnumpy()logging.info('fc_weight:{}, fc_weights_grad:{}'.format(fc_w, fc_w_grad))logging.info('fc_bias:{}, fc_bias_grad:{}'.format(fc_bias, fc_bias_grad))
使用makeloss只能得到损失而不是预测,要得到损失和预测需要使用mx.sym.Group()和mx.sym.BlockGrad()
label = mx.sym.Variable('label')out = mx.sym.Activation(data=final, act_type='sigmoid')ce = label * mx.sym.log(out) + (1 - label) * mx.sym.log(1 - out)weights = mx.sym.Variable('weights')loss = mx.sym.MakeLoss(weigths * ce, normalization='batch')
# -*- coding=utf-8 -*-import mxnet as mximport numpy as npimport logginglogging.basicConfig(level=logging.INFO)x = mx.sym.Variable('data')y = mx.sym.FullyConnected(data=x, num_hidden=1)label = mx.sym.Variable('label')cross_entropy = label * log(out) + (1 - label) * log(1 - out)loss = MakeLoss(cross_entropy)pred_loss = mx.sym.Group([mx.sym.BlockGrad(y), loss])ex = pred_loss.simple_bind(mx.cpu(), data=(32, 2))# testtest_data = mx.nd.array(np.random.random(size=(32, 2)))test_label = mx.nd.array(np.random.random(size=(32, 1)))ex.forward(is_train=True, data=test_data, label=test_label)ex.backward()print ex.arg_dictfc_w = ex.arg_dict['fullyconnected0_weight'].asnumpy()fc_w_grad = ex.grad_arrays[1].asnumpy()fc_bias = ex.arg_dict['fullyconnected0_bias'].asnumpy()fc_bias_grad = ex.grad_arrays[2].asnumpy()logging.info('fc_weight:{}, fc_weights_grad:{}'.format(fc_w, fc_w_grad))logging.info('fc_bias:{}, fc_bias_grad:{}'.format(fc_bias, fc_bias_grad))
import mxnet as mxclass Siamise_metric(mx.metric.EvalMetric): def __init__(self, name='siamise_acc'): super(Siamise_metric, self).__init__(name=name) def update(self, label, pred): preds = pred[0] labels = label[0] preds_label = preds.asnumpy().ravel() labels = labels.asnumpy().ravel() #self.sum_metric += labels[preds_label < 0.5].sum() + len( # labels[preds_label >= 0.5]) - labels[preds_label >= 0.5].sum() #self.num_inst += len(labels) pred = (preds_label < 0.5) acc = (pred == labels).sum() self.sum_metric += acc self.num_inst += len(labels) # numpy.prod(label.shape)class Contrastive_loss(mx.metric.EvalMetric): def __init__(self, name='contrastive_loss'): super(Contrastive_loss, self).__init__(name=name) def update(self, label, pred): loss = pred[1].asnumpy() self.sum_metric += loss self.num_inst += len(loss)
发表评论
最新留言
关注你微信了!
[***.104.42.241]2025年04月02日 00时47分47秒
关于作者

喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!
推荐文章
理解Python系统下的时间格式
2019-03-04
Python语言'类'概念再理解
2019-03-04
OpenAI Gym简介及初级实例
2019-03-04
Ubuntu 18.04 zip压缩文件及其文件 夹中的所以 内容
2019-03-04
int 转 CString
2019-03-04
Edit编辑框自动换行与长度
2019-03-04
低通滤波器的设计
2019-03-04
窄带随机过程的产生
2019-03-04
随机四则运算
2019-03-04
Java面向对象
2019-03-04
JAVA带标签的break和continue
2019-03-04
Java获取线程基本信息的方法
2019-03-04
Java集合Collection
2019-03-04
SpringBoot快速入门
2019-03-04
医疗管理系统-手机快速登录和SpringSecurity权限控制
2019-03-04
vue源码分析(MVVM篇)
2019-03-04
React(八)- ReactUI组件库及Redux的使用
2019-03-04
TypeScript系列文章导航
2019-03-04
base64编码字符串和图片的互转
2019-03-04
汉字转为拼音
2019-03-04