tensorflow系列——读取tfrecord数据
发布日期:2021-09-30 09:33:45
浏览次数:1
分类:技术文章
本文共 8078 字,大约阅读时间需要 26 分钟。
-----------TensorFlow1.x-----------
方式汇总:
- tf.data.experimental.make_batched_features_dataset
- tf.parse_single_example
- tf.parse_example
- example.ParseFromString
- tf.parse_single_sequence_example
- tf.parse_sequence_example
注意:tf也可以写为tf.io.***
一、为现成的estimator创建TrainSpec
用于:
# Model: a pre-made LinearClassifier estimator.
model = tf.estimator.LinearClassifier(
    feature_columns=get_feature_columns(""),
    model_dir=FLAGS.model_dir,
    n_classes=2,
    optimizer=tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate, ),
    config=None,
    warm_start_from=None,
    sparse_combiner='sum')
# Training:
estimator_lib.train_and_evaluate(estimator=model, train_spec=train_spec, eval_spec=eval_spec)
# or equivalently:
# tf.estimator.train_and_evaluate(estimator=model, train_spec=train_spec, eval_spec=eval_spec)
1、获取tfrecord文件目录
# Expand every glob pattern in path_list into concrete tfrecord file paths.
for path in path_list:
    file_list.extend(tf.io.gfile.glob(path))
2、直接解析tfrecord数据
# Parse schema: all float features AND the label are packed into one
# 43-dim "user_data" vector, so label_key is deliberately left as None.
feature_schema = {
    "user_data": tf.io.FixedLenFeature(shape=(43,), dtype=tf.float32),
    # "label": tf.io.FixedLenFeature(shape=(1,),dtype=tf.float32)
}
dataTestTrain = tf.data.experimental.make_batched_features_dataset(
    file_pattern=file_list,
    batch_size=FLAGS.train_batch_size,
    features=feature_schema,
    label_key=None,
    num_epochs=FLAGS.train_epochs,
    shuffle=True,
    shuffle_buffer_size=FLAGS.train_shuffle_buffer,
    shuffle_seed=random.randint(0, 1000000),
    reader_num_threads=FLAGS.reader_num_threads,
    parser_num_threads=FLAGS.parser_num_threads,
    drop_final_batch=True)
3、处理解析的tfrecord数据
# 说明:将各个特征(浮点型/类别整型)及标签label数据从user_data中拆解出来,并做相应转换处理。def parse_exmp_batched(serial_exmp): oriAllData = serial_exmp.get("user_data") feaDics=dict() retainLabel = oriAllData[:,0:1] feaDics["sta_fea1"]=oriAllData[:,1:18] feaDics["click_level"]=tf.cast(oriAllData[:,18:19],dtype=tf.int64) return feaDics,tf.identity(retainLabel,"label")return dataTestTrain.map(parse_exmp_batched,num_parallel_calls=8)
4、若想直接使用解析到的tfrecord数据
feature_schema = { # user features "sex": tf.io.FixedLenFeature(shape=(1,), dtype=tf.int64), "age": tf.io.FixedLenFeature(shape=(1,), dtype=tf.int64), "label": tf.io.FixedLenFeature(shape=(1,), dtype=tf.float32)}return tf.data.experimental.make_batched_features_dataset( file_pattern=eval_files, batch_size=FLAGS.eval_batch_size, features=feature_schema, label_key="label", num_epochs=10, shuffle=False, shuffle_buffer_size=FLAGS.eval_batch_size, reader_num_threads=FLAGS.reader_num_threads, parser_num_threads=FLAGS.parser_num_threads, drop_final_batch=False)
二、为自定义estimator创建TrainSpec
用于:
# Build the model with a custom estimator.
dnn_model = MyEstimator(
    model_dir=FLAGS.model_dir,
    optimizer=tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate),
    hidden_units=list(map(lambda x: int(x), FLAGS.hidden.split(","))),
    activation_fn=tf.nn.relu,
    dropout=FLAGS.dropout,
    batch_norm=True,
    weight_column=None,
    label_vocabulary=None,
    loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE,
    params=None,
    # config=config,
    warm_start_from=None,
)
# Training:
# estimator_lib.train_and_evaluate(estimator=model, train_spec=train_spec, eval_spec=eval_spec)
# or tf.estimator.train_and_evaluate(estimator=model, train_spec=train_spec, eval_spec=eval_spec)
1、获取tfrecord文件目录
# path_list holds the full paths (glob patterns) of all tfrecord files.
for path in path_list:
    file_list.extend(tf.io.gfile.glob(path))
input_files = tf.data.Dataset.list_files(file_list)
2、获取原始的tfrecord数据
# Read tfrecord files concurrently, interleaving records from
# FLAGS.reader_num_threads files at a time.
# FIX: tf.contrib.data.parallel_interleave is deprecated (tf.contrib was
# removed in TF2); tf.data.experimental.parallel_interleave is the drop-in
# replacement, and this file already uses tf.data.experimental elsewhere.
dataset = input_files.apply(
    tf.data.experimental.parallel_interleave(
        tf.data.TFRecordDataset,
        cycle_length=FLAGS.reader_num_threads))
3、解析获取的tfrecord数据
def parse_exmp(serial_exmp): # 可以依据feature_column生成'user_data' # feature_spec = tf.feature_column.make_parse_example_spec(feature_column) oriExample = tf.parse_single_example(serial_exmp,features={'user_data':tf.FixedLenFeature([43], tf.float32)}) oriAllData = oriExample.get("user_data") feaDics=dict() retainLabel = oriAllData[0:1] feaDics["sta_fea1"]=oriAllData[1:18] feaDics["click_level"]=tf.cast(oriAllData[18:19],dtype=tf.int64) return feaDics,{"label":tf.to_float(retainLabel)}dataset = dataset.map(parse_exmp,num_parallel_calls=8)dataset = dataset.repeat().batch(FLAGS.train_batch_size).prefetch(1)return dataset
三、本地调试打印tfrecord数据
1、使用make_one_shot_iterator获取tfrecord数据并用make_batched_features_dataset方式解析数据并用session-run方式打印
feature_schema = {"user_data": tf.io.FixedLenFeature(shape=(43,), dtype=tf.float32)}

def parse_exmp_batched(serial_exmp):
    """Split the parsed "user_data" vector into a feature dict + label dict."""
    oriAllData = serial_exmp.get("user_data")
    feaDics = dict()
    retainLabel = oriAllData[0:1]
    feaDics["sta_fea1"] = oriAllData[1:18]
    feaDics["click_level"] = tf.cast(oriAllData[18:19], dtype=tf.int64)
    return feaDics, {"label": tf.to_float(retainLabel)}

def train_input_fn():
    return tf.data.experimental.make_batched_features_dataset(
        file_pattern=train_files,
        batch_size=10,
        features=feature_schema,
        label_key=None,
        num_epochs=5,
        shuffle=True,
        shuffle_buffer_size=2000,
        shuffle_seed=random.randint(0, 1000000),
        reader_num_threads=4,
        parser_num_threads=4,
        drop_final_batch=True)

dataTest = train_input_fn()
dataset = dataTest.map(parse_exmp_batched, num_parallel_calls=8)
test_op = dataset.make_one_shot_iterator()
one_element = test_op.get_next()
with tf.Session() as sess:
    for i in range(1):
        print(sess.run(one_element))
        # print(sess.run(one_element['user_data']))
        print(sess.run([one_element[1]['label'], one_element[0]['sta_fea1']]))
2、使用tf.data.TFRecordDataset获取tfrecord数据并用parse_single_example解析用session-run打印
def parse_exmp(serial_exmp):
    """Parse one serialized Example into (features, label) dicts."""
    # Alternative: derive the spec from feature_columns.
    # feature_spec = tf.feature_column.make_parse_example_spec(feature_column)
    oriExample = tf.parse_single_example(
        serial_exmp,
        features={'user_data': tf.FixedLenFeature([43], tf.float32)})
    oriAllData = oriExample.get("user_data")
    feaDics = dict()
    retainLabel = oriAllData[0:1]
    feaDics["sta_fea1"] = oriAllData[1:18]
    feaDics["actDay_fea1"] = oriAllData[20:33]
    feaDics["act_first_fea1"] = tf.cast(oriAllData[33:34], dtype=tf.int64)
    return feaDics, {"label": tf.to_float(retainLabel)}

train_files = [...]
input_files = tf.data.Dataset.list_files(train_files)
dataset = input_files.apply(
    tf.contrib.data.parallel_interleave(
        tf.data.TFRecordDataset, cycle_length=reader_num_threads))
dataset = dataset.map(parse_exmp, num_parallel_calls=8)
element = dataset.make_one_shot_iterator()

# FIXES vs. original snippet:
# - `for i in range(0)` never executed (dead loop) -> range(1)
# - `sess` was used but no Session was ever created -> `with tf.Session()`
# - tf.check_numerics signals NaN/Inf via tf.errors.InvalidArgumentError,
#   never ZeroDivisionError, so the original handler could not fire
# - `float_val` could be unbound in the `finally` block
with tf.Session() as sess:
    float_val = None
    for i in range(1):
        try:
            tmp = element.get_next()[0]['actDay_fea1']
            print(sess.run(tmp))
            check_a = tf.check_numerics(tmp, "non number")
            float_val = sess.run(check_a)
            print(len(float_val), float_val)
        except tf.errors.InvalidArgumentError as e:
            print("inf值", e)
        finally:
            if float_val is not None:
                print(str(len(float_val)), str(float_val))
4、使用python_io.tf_record_iterator方式获取tfrecord数据并用ParseFromString解析并用print打印
# tf_records_filenames = "..."
# Decode each raw record with the Example protobuf directly — no graph,
# no session needed, so values print as plain Python objects.
for record in tf.python_io.tf_record_iterator(tf_records_filenames):
    example = tf.train.Example()
    example.ParseFromString(record)
    ltv4v = example.features.feature['ltv4'].int64_list.value
    print(ltv4v)
5、使用python_io.tf_record_iterator方式获取tfrecord数据并用parse_single_example解析并使用session-run打印
注意:此方式已经废弃
sess = tf.Session()
for record in tf.python_io.tf_record_iterator(tf_records_filenames):
    # Parse the raw record into a dict of tensors keyed by feature name.
    features = tf.parse_single_example(
        record,
        features={
            "label": tf.FixedLenFeature([], tf.float32),
            "ltv4": tf.FixedLenFeature([], tf.int64),
            # "game_id": tf.FixedLenFeature([], tf.string),
        })
    float_val = sess.run(features['ltv4'])
-----------TensorFlow2.x-----------
四、本地打印tfrecord数据
参考:
转载地址:https://blog.csdn.net/h_jlwg6688/article/details/116663346 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!
发表评论
最新留言
关注你微信了!
[***.104.42.241]2024年04月20日 15时32分23秒
关于作者
喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!
推荐文章
Android软键盘(1)---输入法界面管理(打开/关闭/状态获取)
2019-04-27
Android动态设置view的高度宽度
2019-04-27
css3 属性 text-overflow 实现截取多余文字内容 以省略号来代替多余内容
2019-04-27
vue 事件总线EventBus的概念、使用以及注意点
2019-04-27
JavaScript 用七种方式教你判断一个变量是否为数组类型
2019-04-27
黄家懿:河北高校邀请赛 -- 二手车交易价格预测决赛答辩
2019-04-27
如何利用pyecharts绘制酷炫的桑基图?
2019-04-27
王朝阳:河北高校邀请赛 -- 二手车交易价格预测决赛答辩
2019-04-27
Scratch等级考试(二级)模拟题
2019-04-27
如何在Jupyter Lab中显示pyecharts的图形?
2019-04-27
什么是Python之禅?
2019-04-27
【青少年编程】【Scratch】01 运动模块
2019-04-27
json的序列化与反序列化
2019-04-27
【第16周复盘】学习的飞轮
2019-04-27
如何利用pyecharts绘制炫酷的关系网络图?
2019-04-27
NCEPU:线下组队学习周报(007)
2019-04-27
【青少年编程】【二级】寻找宝石
2019-04-27
【组队学习】【26期】Linux教程
2019-04-27