
MXNet端到端车牌/验证码识别中的训练代码
发布日期:2021-05-07 05:53:19
浏览次数:20
分类:精选文章
本文共 5412 字,大约阅读时间需要 18 分钟。
# coding=utf-8
"""End-to-end license-plate OCR training script (MXNet).

Synthesizes Chinese license-plate images on the fly via the project-local
``genplate`` helper and trains a small CNN with seven parallel 65-way
softmax heads — one head per plate character.
"""
import sys

sys.path.insert(0, "../../python")

import mxnet as mx
import numpy as np
import cv2
import random
from io import BytesIO
from genplate import *  # provides chars, GenPlate and the random helper r()


class OCRBatch(object):
    """One mini-batch: parallel lists of data/label NDArrays and their names."""

    def __init__(self, data_names, data, label_names, label):
        self.data = data
        self.label = label
        self.data_names = data_names
        self.label_names = label_names

    @property
    def provide_data(self):
        # (name, shape) pairs, as required by the mx.io data-batch contract.
        return [(n, x.shape) for n, x in zip(self.data_names, self.data)]

    @property
    def provide_label(self):
        return [(n, x.shape) for n, x in zip(self.label_names, self.label)]


def rand_range(lo, hi):
    """Return a random integer offset into [lo, hi); r() comes from genplate."""
    return lo + r(hi - lo)


def gen_rand():
    """Draw a random 7-character plate.

    Character-index layout (into genplate's ``chars`` table):
      position 0: index in [0, 31)   — presumably the province glyph; confirm
      position 1: index in [41, 65)  — presumably a letter; confirm vs chars
      positions 2-6: index in [31, 65) — letters/digits

    Returns:
        (name, label): the plate string and the list of 7 indices.
    """
    label = [rand_range(0, 31), rand_range(41, 65)]
    for _ in range(5):
        label.append(rand_range(31, 65))
    name = "".join(chars[idx] for idx in label)
    return name, label


def gen_sample(genplate, width, height):
    """Generate one training sample: (label indices, normalized CHW image)."""
    num, label = gen_rand()
    img = genplate.generate(num)
    img = cv2.resize(img, (width, height))
    img = np.multiply(img, 1 / 255.0)  # scale pixel values into [0, 1]
    img = img.transpose(2, 0, 1)       # HWC -> CHW, the layout MXNet expects
    return label, img


class OCRIter(mx.io.DataIter):
    """Data iterator that synthesizes ``count`` plate images per epoch."""

    def __init__(self, count, batch_size, num_label, height, width):
        super(OCRIter, self).__init__()
        self.genplate = GenPlate("./font/platech.ttf",
                                 './font/platechar.ttf', './NoPlates')
        self.batch_size = batch_size
        self.count = count
        self.height = height
        self.width = width
        self.provide_data = [('data', (batch_size, 3, height, width))]
        self.provide_label = [('softmax_label', (self.batch_size, num_label))]
        print("start")

    def __iter__(self):
        # Partial trailing batches are dropped by the integer division.
        for _ in range(int(self.count / self.batch_size)):
            data = []
            label = []
            for _ in range(self.batch_size):
                num, img = gen_sample(self.genplate, self.width, self.height)
                data.append(img)
                label.append(num)
            data_all = [mx.nd.array(data)]
            label_all = [mx.nd.array(label)]
            yield OCRBatch(['data'], data_all, ['softmax_label'], label_all)

    def reset(self):
        pass


def get_ocrnet():
    """Build the network symbol.

    Two conv/pool/relu stages, a 120-unit FC trunk, then seven independent
    65-way classifier heads (one per plate character) concatenated along
    dim 0 so a single SoftmaxOutput covers all positions.
    """
    data = mx.symbol.Variable('data')
    label = mx.symbol.Variable('softmax_label')

    conv1 = mx.symbol.Convolution(data=data, kernel=(5, 5), num_filter=32)
    pool1 = mx.symbol.Pooling(data=conv1, pool_type="max",
                              kernel=(2, 2), stride=(1, 1))
    relu1 = mx.symbol.Activation(data=pool1, act_type="relu")

    conv2 = mx.symbol.Convolution(data=relu1, kernel=(5, 5), num_filter=32)
    pool2 = mx.symbol.Pooling(data=conv2, pool_type="avg",
                              kernel=(2, 2), stride=(1, 1))
    relu2 = mx.symbol.Activation(data=pool2, act_type="relu")

    flatten = mx.symbol.Flatten(data=relu2)
    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=120)

    # One 65-class head per plate character position.
    heads = [mx.symbol.FullyConnected(data=fc1, num_hidden=65)
             for _ in range(7)]
    fc2 = mx.symbol.Concat(*heads, dim=0)

    # Labels arrive as (batch, 7); transpose + flatten so the label order
    # matches fc2's concat-along-dim-0 layout (position-major).
    label = mx.symbol.transpose(data=label)
    label = mx.symbol.Reshape(data=label, target_shape=(0,))
    return mx.symbol.SoftmaxOutput(data=fc2, label=label, name="softmax")


def Accuracy(label, pred):
    """Whole-plate accuracy: a sample scores only if all 7 characters match."""
    label = label.T.reshape((-1,))
    hit = 0
    total = 0
    for i in range(int(pred.shape[0] / 7)):
        ok = True
        for j in range(7):
            k = i * 7 + j
            if np.argmax(pred[k]) != int(label[k]):
                ok = False
                break
        if ok:
            hit += 1
        total += 1
    return 1.0 * hit / total


def train():
    """Train the OCR net on synthetic plates and save it as 'cnn-ocr'."""
    network = get_ocrnet()
    # NOTE(review): devs is built but never passed to FeedForward (no ctx=
    # argument), so training runs on the default device. Likely a bug in the
    # original — confirm before wiring devs in, as doing so requires a GPU.
    devs = [mx.gpu(i) for i in range(1)]
    model = mx.model.FeedForward(
        symbol=network,
        num_epoch=1,
        learning_rate=0.001,
        wd=0.00001,
        initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
        momentum=0.9)

    batch_size = 100
    data_train = OCRIter(100000, batch_size, 7, 30, 120)
    data_test = OCRIter(1000, batch_size, 7, 30, 120)

    import logging
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)

    model.fit(X=data_train,
              eval_data=data_test,
              eval_metric=Accuracy,
              batch_end_callback=mx.callback.Speedometer(batch_size, 100))
    model.save("cnn-ocr")
    print(gen_rand())


if __name__ == '__main__':
    train()
发表评论
最新留言
网站不错 人气很旺了 加油
[***.192.178.218]2025年04月16日 22时16分34秒
关于作者

喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!
推荐文章
MySQL用户管理:添加用户、授权、删除用户
2019-03-06
比技术还重要的事
2019-03-06
linux线程调度策略
2019-03-06
软中断和实时性
2019-03-06
Linux探测工具BCC(可观测性)
2019-03-06
流量控制--2.传统的流量控制元素
2019-03-06
SNMP介绍及使用,超有用,建议收藏!
2019-03-06
51nod 1596 搬货物(二进制处理)
2019-03-06
来自星星的祝福(容斥+排列组合)
2019-03-06
Hmz 的女装(递推)
2019-03-06
HDU5589:Tree(莫队+01字典树)
2019-03-06
不停机替换线上代码? 你没听错,Arthas它能做到
2019-03-06
sharding-jdbc 分库分表的 4种分片策略,还蛮简单的
2019-03-06
分库分表的 9种分布式主键ID 生成方案,挺全乎的
2019-03-06
MySQL不会丢失数据的秘密,就藏在它的 7种日志里
2019-03-06
Python开发之序列化与反序列化:pickle、json模块使用详解
2019-03-06
回顾-生成 vs 判别模型-和图
2019-03-06
采坑 - 字符串的 "" 与 pd.isnull()
2019-03-06
无序列表 - 链表
2019-03-06
SQL 查询强化 - 数据准备
2019-03-06