Tensorflow端到端车牌/验证码识别中的训练代码（注：下面的训练代码实际基于 MXNet 实现）
发布日期:2021-05-07 05:53:19 浏览次数:20 分类:精选文章

本文共 5412 字,大约阅读时间需要 18 分钟。

# coding=utf-8
"""Train a CNN license-plate recognizer with MXNet on synthetic plates.

Plates are rendered on the fly by ``GenPlate`` (from the local ``genplate``
module). Each plate has ``PLATE_LEN`` characters; the network predicts every
character position with its own ``NUM_CLASSES``-way fully-connected head.
"""
import sys

sys.path.insert(0, "../../python")  # prefer a locally built MXNet python package

import mxnet as mx
import numpy as np
import cv2
import random
from io import BytesIO
# NOTE(review): assumed to provide GenPlate, the `chars` table and the r()
# helper used below — TODO confirm against genplate.py.
from genplate import *

# Characters per plate (one classifier head each).
PLATE_LEN = 7
# Output size of every per-character classifier head.
NUM_CLASSES = 65


class OCRBatch(object):
    """Minimal data-batch container consumed by ``mx.model.FeedForward``."""

    def __init__(self, data_names, data, label_names, label):
        self.data = data
        self.label = label
        self.data_names = data_names
        self.label_names = label_names

    @property
    def provide_data(self):
        """List of ``(name, shape)`` pairs for the data arrays."""
        return [(n, x.shape) for n, x in zip(self.data_names, self.data)]

    @property
    def provide_label(self):
        """List of ``(name, shape)`` pairs for the label arrays."""
        return [(n, x.shape) for n, x in zip(self.label_names, self.label)]


def rand_range(lo, hi):
    """Return ``lo + r(hi - lo)``.

    Presumably an int in ``[lo, hi)`` if genplate's ``r(n)`` returns an int
    in ``[0, n)`` — TODO confirm.
    """
    return lo + r(hi - lo)


def gen_rand():
    """Draw a random plate and return ``(text, label_indices)``.

    Index ranges drawn via ``rand_range``: the first character from [0, 31),
    the second from [41, 65), the remaining five from [31, 65). ``text`` is
    the corresponding characters looked up in ``chars``.
    """
    # Draw in the same order as before: position 0, position 1, then 5 more.
    label = [rand_range(0, 31), rand_range(41, 65)]
    label.extend(rand_range(31, 65) for _ in range(5))
    name = "".join(chars[idx] for idx in label)
    return name, label


def gen_sample(genplate, width, height):
    """Render one random plate and return ``(label_indices, image)``.

    The image is resized to ``(width, height)``, scaled by 1/255 and
    transposed with axes (2, 0, 1) — i.e. channel-first, assuming the
    generator returns an HxWxC array (TODO confirm).
    """
    text, label = gen_rand()
    img = genplate.generate(text)
    img = cv2.resize(img, (width, height))
    img = np.multiply(img, 1 / 255.0)
    img = img.transpose(2, 0, 1)
    return label, img


class OCRIter(mx.io.DataIter):
    """Data iterator that synthesizes ``count // batch_size`` random plate batches.

    Every pass generates new random plates; nothing is cached between epochs.
    """

    def __init__(self, count, batch_size, num_label, height, width):
        super(OCRIter, self).__init__()
        self.genplate = GenPlate("./font/platech.ttf", './font/platechar.ttf', './NoPlates')
        self.batch_size = batch_size
        self.count = count
        self.height = height
        self.width = width
        self.provide_data = [('data', (batch_size, 3, height, width))]
        self.provide_label = [('softmax_label', (self.batch_size, num_label))]
        print("start")

    def __iter__(self):
        """Yield freshly generated ``OCRBatch`` objects."""
        for _ in range(self.count // self.batch_size):
            images = []
            labels = []
            for _ in range(self.batch_size):
                label, img = gen_sample(self.genplate, self.width, self.height)
                images.append(img)
                labels.append(label)
            data_all = [mx.nd.array(images)]
            label_all = [mx.nd.array(labels)]
            yield OCRBatch(['data'], data_all, ['softmax_label'], label_all)

    def reset(self):
        """No-op: ``__iter__`` regenerates data on every pass."""
        pass


def get_ocrnet():
    """Build the CNN symbol with one ``NUM_CLASSES``-way head per plate character.

    Head outputs are concatenated along axis 0, so output row
    ``j * batch + i`` is character ``j`` of sample ``i``. The label is
    transposed and flattened inside the graph to match that layout.
    """
    data = mx.symbol.Variable('data')
    label = mx.symbol.Variable('softmax_label')
    conv1 = mx.symbol.Convolution(data=data, kernel=(5, 5), num_filter=32)
    pool1 = mx.symbol.Pooling(data=conv1, pool_type="max", kernel=(2, 2), stride=(1, 1))
    relu1 = mx.symbol.Activation(data=pool1, act_type="relu")

    conv2 = mx.symbol.Convolution(data=relu1, kernel=(5, 5), num_filter=32)
    pool2 = mx.symbol.Pooling(data=conv2, pool_type="avg", kernel=(2, 2), stride=(1, 1))
    relu2 = mx.symbol.Activation(data=pool2, act_type="relu")

    flatten = mx.symbol.Flatten(data=relu2)
    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=120)
    # One head per character position; created in order so auto-generated
    # symbol names match the previous hand-unrolled version.
    heads = [mx.symbol.FullyConnected(data=fc1, num_hidden=NUM_CLASSES)
             for _ in range(PLATE_LEN)]
    fc2 = mx.symbol.Concat(*heads, dim=0)
    label = mx.symbol.transpose(data=label)
    label = mx.symbol.Reshape(data=label, target_shape=(0, ))
    return mx.symbol.SoftmaxOutput(data=fc2, label=label, name="softmax")


def Accuracy(label, pred, num_chars=PLATE_LEN):
    """Return the fraction of plates whose characters are all predicted correctly.

    Args:
        label: label matrix as fed by ``OCRIter`` — presumably shape
            (batch, num_chars); TODO confirm. It is transposed and flattened
            so it follows the same position-major order as ``pred``.
        pred: network output with one row per (position, sample) pair,
            ordered ``j * batch + i`` (heads concatenated along axis 0).
        num_chars: characters per plate.

    Raises:
        ValueError: if ``pred`` has fewer than ``num_chars`` rows.
    """
    label = np.asarray(label).T.reshape((-1, ))
    pred = np.asarray(pred)
    batch = pred.shape[0] // num_chars
    if batch == 0:
        raise ValueError("pred has fewer rows than num_chars")
    rows = batch * num_chars
    char_ok = np.argmax(pred[:rows], axis=1) == label[:rows].astype(np.int64)
    # Column i of the (num_chars, batch) view holds every character of sample i.
    plate_ok = char_ok.reshape(num_chars, batch).all(axis=0)
    return float(plate_ok.mean())


def train():
    """Train the OCR network on synthetic plates and save it with prefix ``cnn-ocr``."""
    network = get_ocrnet()
    devs = [mx.gpu(i) for i in range(1)]
    model = mx.model.FeedForward(
        # Pass the device list explicitly; otherwise FeedForward uses its
        # default context and `devs` is ignored.
        ctx=devs,
        symbol=network,
        num_epoch=1,
        learning_rate=0.001,
        wd=0.00001,
        initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
        momentum=0.9)

    batch_size = 100
    data_train = OCRIter(100000, batch_size, PLATE_LEN, 30, 120)
    data_test = OCRIter(1000, batch_size, PLATE_LEN, 30, 120)

    import logging
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)

    model.fit(X=data_train, eval_data=data_test, eval_metric=Accuracy,
              batch_end_callback=mx.callback.Speedometer(batch_size, 100))
    model.save("cnn-ocr")
    print(gen_rand())


if __name__ == '__main__':
    train()
上一篇:在tensorflow object detection安装出现错误总结
下一篇:Tensorflow端到端车牌/验证码识别中的车牌生成器

发表评论

最新留言

网站不错 人气很旺了 加油
[***.192.178.218]2025年04月16日 22时16分34秒