tensorflow系列——读取tfrecord数据
发布日期:2021-09-30 09:33:45 浏览次数:1 分类:技术文章

本文共 8078 字,大约阅读时间需要 26 分钟。

-----------TensorFlow1.x-----------

方式汇总:

  • tf.data.experimental.make_batched_features_dataset
  • tf.parse_single_example
  • tf.parse_example
  • example.ParseFromString
  • tf.parse_single_sequence_example
  • tf.parse_sequence_example

注意:tf也可以写为tf.io.***

一、为现成的estimator创建TrainSpec

用于:

# 模型:model = tf.estimator.LinearClassifier(        feature_columns=get_feature_columns(""),        model_dir=FLAGS.model_dir,        n_classes=2,        optimizer=tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate,                     ),        config=None,        warm_start_from=None,        sparse_combiner='sum')# 训练方式:estimator_lib.train_and_evaluate(estimator=model,train_spec=train_spec,eval_spec=eval_spec)# 或tf.estimator.train_and_evaluate(estimator=model,train_spec=train_spec,eval_spec=eval_spec)

1、获取tfrecord文件目录

for path in path_list:        file_list.extend(tf.io.gfile.glob(path))

2、直接解析tfrecord数据

feature_schema = {    "user_data": tf.io.FixedLenFeature(shape=(43,),dtype=tf.float32),    # "label": tf.io.FixedLenFeature(shape=(1,),dtype=tf.float32)}# 注意:将所有浮点特征数据以及label都放到了user_data中,所以label_key设置为NonedataTestTrain =  tf.data.experimental.make_batched_features_dataset(            file_pattern=file_list,            batch_size=FLAGS.train_batch_size,            features=feature_schema,            label_key=None,            num_epochs=FLAGS.train_epochs,            shuffle=True,            shuffle_buffer_size=FLAGS.train_shuffle_buffer,            shuffle_seed=random.randint(0, 1000000),            reader_num_threads=FLAGS.reader_num_threads,            parser_num_threads=FLAGS.parser_num_threads,            drop_final_batch=True)

 

3、处理解析的tfrecord数据

# 说明:将各个特征(浮点型/类别整型)及标签label数据从user_data中拆解出来,并做相应转换处理。def parse_exmp_batched(serial_exmp):        oriAllData = serial_exmp.get("user_data")        feaDics=dict()        retainLabel = oriAllData[:,0:1]        feaDics["sta_fea1"]=oriAllData[:,1:18]        feaDics["click_level"]=tf.cast(oriAllData[:,18:19],dtype=tf.int64)        return feaDics,tf.identity(retainLabel,"label")return dataTestTrain.map(parse_exmp_batched,num_parallel_calls=8)

4、若想直接使用解析到的tfrecord数据

feature_schema = {    # user features    "sex": tf.io.FixedLenFeature(shape=(1,), dtype=tf.int64),    "age": tf.io.FixedLenFeature(shape=(1,), dtype=tf.int64),    "label": tf.io.FixedLenFeature(shape=(1,), dtype=tf.float32)}return tf.data.experimental.make_batched_features_dataset(            file_pattern=eval_files,            batch_size=FLAGS.eval_batch_size,            features=feature_schema,            label_key="label",            num_epochs=10,            shuffle=False,            shuffle_buffer_size=FLAGS.eval_batch_size,            reader_num_threads=FLAGS.reader_num_threads,            parser_num_threads=FLAGS.parser_num_threads,            drop_final_batch=False)

二、为自定义estimator创建TrainSpec

用于:

# 使用自定义estimator创建模型dnn_model =  MyEstimator(        model_dir=FLAGS.model_dir,        optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate),        hidden_units = list(map(lambda x:int(x),FLAGS.hidden.split(","))),        activation_fn = tf.nn.relu,        dropout = FLAGS.dropout,        batch_norm = True,        weight_column = None,        label_vocabulary=None,        loss_reduction = tf.losses.Reduction.SUM_OVER_BATCH_SIZE,        params=None,        # config=config,        warm_start_from=None    )# 训练方式:estimator_lib.train_and_evaluate(estimator=model,train_spec=train_spec,eval_spec=eval_spec)或tf.estimator.train_and_evaluate(estimator=model,train_spec=train_spec,eval_spec=eval_spec)

1、获取tfrecord文件目录

# path_list为所有的tf文件的完整路径for path in path_list:        file_list.extend(tf.io.gfile.glob(path))input_files = tf.data.Dataset.list_files(file_list)

2、获取原始的tfrecord数据

dataset = input_files.apply(tf.contrib.data.parallel_interleave(tf.data.TFRecordDataset,cycle_length=FLAGS.reader_num_threads))

3、解析获取的tfrecord数据

def parse_exmp(serial_exmp):        # 可以依据feature_column生成'user_data'        # feature_spec = tf.feature_column.make_parse_example_spec(feature_column)        oriExample = tf.parse_single_example(serial_exmp,features={'user_data':tf.FixedLenFeature([43], tf.float32)})        oriAllData = oriExample.get("user_data")        feaDics=dict()        retainLabel = oriAllData[0:1]        feaDics["sta_fea1"]=oriAllData[1:18]        feaDics["click_level"]=tf.cast(oriAllData[18:19],dtype=tf.int64)        return feaDics,{"label":tf.to_float(retainLabel)}dataset = dataset.map(parse_exmp,num_parallel_calls=8)dataset = dataset.repeat().batch(FLAGS.train_batch_size).prefetch(1)return dataset

三、本地调试打印tfrecord数据

1、使用make_one_shot_iterator获取tfrecord数据并用make_batched_features_dataset方式解析数据并用session-run方式打印

feature_schema = {    "user_data": tf.io.FixedLenFeature(shape=(43,),dtype=tf.float32)}def parse_exmp_batched(serial_exmp):        oriAllData = serial_exmp.get("user_data")        feaDics=dict()        retainLabel = oriAllData[0:1]        feaDics["sta_fea1"]=oriAllData[1:18]        feaDics["click_level"]=tf.cast(oriAllData[18:19],dtype=tf.int64)        return feaDics,{"label":tf.to_float(retainLabel)}def train_input_fn():        return tf.data.experimental.make_batched_features_dataset(            file_pattern=train_files,            batch_size=10,            features=feature_schema,            label_key=None,            num_epochs=5,            shuffle=True,            shuffle_buffer_size=2000,            shuffle_seed=random.randint(0, 1000000),            reader_num_threads=4,            parser_num_threads=4,            drop_final_batch=True)dataTest = train_input_fn()dataset = dataTest.map(parse_exmp_batched,num_parallel_calls=8)test_op = dataset.make_one_shot_iterator()one_element = test_op.get_next()with tf.Session() as sess:    for i in range(1):        print(sess.run(one_element))        # print(sess.run(one_element['user_data']))        print(sess.run([one_element[1]['label'],one_element[0]['sta_fea1']]))

2、使用tf.data.TFRecordDataset获取tfrecord数据并用parse_single_example解析用session-run打印

def parse_exmp(serial_exmp):        feature_spec = tf.feature_column.make_parse_example_spec(feature_column)        oriExample = tf.parse_single_example(serial_exmp,features={'user_data':tf.FixedLenFeature([43], tf.float32)})        oriAllData = oriExample.get("user_data")        feaDics=dict()        retainLabel = oriAllData[0:1]        feaDics["sta_fea1"]=oriAllData[1:18]        feaDics["actDay_fea1"]=oriAllData[20:33]        feaDics["act_first_fea1"]=tf.cast(oriAllData[33:34],dtype=tf.int64)        return feaDics,{"label":tf.to_float(retainLabel)}train_files = [...]input_files = tf.data.Dataset.list_files(train_files)dataset = input_files.apply(tf.contrib.data.parallel_interleave(tf.data.TFRecordDataset,cycle_length=reader_num_threads))dataset = dataset.map(parse_exmp,num_parallel_calls=8)element = dataset.make_one_shot_iterator()for i in range(0):    try:        tmp = element.get_next()[0]['actDay_fea1']        print(sess.run(tmp))        check_a=tf.check_numerics(tmp,"non number"+str(sess.run(tmp)))        float_val = sess.run(check_a)        print(len(float_val),float_val)    except ZeroDivisionError as e:        print("inf值")    finally:        print(str(len(float_val)),str(float_val))

 

3、使用python_io.tf_record_iterator方式获取tfrecord数据并用ParseFromString解析并用print打印

# tf_records_filenames = "..."for record in tf.python_io.tf_record_iterator(tf_records_filenames):        example = tf.train.Example()        example.ParseFromString(record)        ltv4v = example.features.feature['ltv4'].int64_list.value        print(ltv4v)

4、使用python_io.tf_record_iterator方式获取tfrecord数据并用parse_single_example解析并使用session-run打印

注意:此方式已经废弃

sess =tf.Session()for record in tf.python_io.tf_record_iterator(tf_records_filenames):        # 根据key名字得到保存的features字典        features = tf.parse_single_example(record,        features={        "label":tf.FixedLenFeature([], tf.float32),            "ltv4":tf.FixedLenFeature([],tf.int64),            #"game_id":tf.FixedLenFeature([],tf.string),        })        float_val = sess.run(features['ltv4'])

-----------TensorFlow2.x-----------

四、本地打印tfrecord数据

参考:

转载地址:https://blog.csdn.net/h_jlwg6688/article/details/116663346 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!

上一篇:TensorFlow系列——在estimator中使用feature_column处理tfrecord特征
下一篇:sql系列——hive之array、map、struct、java函数(udf)、python函数、分隔符、json_tuple的处理

发表评论

最新留言

关注你微信了!
[***.104.42.241]2024年04月20日 15时32分23秒

关于作者

    喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!

推荐文章