TensorFlow系列——在estimator中使用feature_column处理tfrecord特征
发布日期:2021-09-30 09:33:46
浏览次数:1
分类:技术文章
本文共 3929 字,大约阅读时间需要 13 分钟。
一、用于现成的estimator模型
1、读取tfrecord数据
1.1、tfrecord中包含所有(特征名-值)以及标签的情况
feature_schema = { # 包含了tfrecord里的所有特征,包括标签label "sex": tf.io.FixedLenFeature(shape=(1,), dtype=tf.int64), "age": tf.io.FixedLenFeature(shape=(1,), dtype=tf.int64), "label": tf.io.FixedLenFeature(shape=(1,), dtype=tf.float32)}# train_files就是tfrecord文件列表tf.data.experimental.make_batched_features_dataset( file_pattern=train_files, features=feature_schema, label_key="label")
1.2、tfrecord中将各个特征数据放在一个数组中的情况("user_data"—array[值])
参考:
2、定义feature_column
def get_feature_columns(args): user_sex = tf.feature_column.categorical_column_with_identity(key="sex", num_buckets=3, default_value=0) user_age = tf.feature_column.categorical_column_with_identity(key="age", num_buckets=9, default_value=0) return [user_sex,user_age]
3、使用已有的estimator模型
dnn_model = estimator_lib.DNNClassifier(hidden_units=[...], model_dir=FLAGS.model_dir, feature_columns=get_feature_columns([...]),...)
二、使用input_layer方式用于自定义estimator模型
1、读取tfrecord数据
方式同上
2、定义feature_column
方式同上
3、在自定义estimator中使用feature_column.input_layer方法
model_fn为自定义estimator的主体部分
# feature_columns_new 为2中定义的feature_column# features 为从tfrecord中获取的tensor# inputs_layers 为神经网络层的输入def model_fn(features,labels,mode): inputs_layers =tf.feature_column.input_layer(features,feature_columns_new) # 可以多次使用input_layer ...
4、使用自定义estimator
# 不需要feature及feature_column及feature_schema作为参数输入dnn_model = MyEstimator( model_dir=FLAGS.model_dir, optimizer = ..., hidden_units = [...])# 通过下面方法直接从train_spec输出到自定义的estimator的model_fn方法中tf.estimator.train_and_evaluate(estimator=dnn_model,train_spec=train_spec,eval_spec=eval_spec)
5、问题说明
以上使用的版本为TensorFlow1.x,对于TensorFlow2.x修改如下:
from tensorflow.python.feature_column import feature_column as fc_v1rs = fc_v1.input_layer(features=feaDics,feature_columns=get_feature_columns_new())# 或者:rs = tf.compat.v1.feature_column.input_layer(features=feaDics,feature_columns=get_feature_columns_new())
三、自定义输入层用于自定义estimator模型
1、读取tfrecord数据
方式同上
2、定义feature_column
方式同上
3、重写自定义的inputlayer输入层
参考:
# 重写自定义层的标准方式:class myInputLayer(tf.keras.layers.Layer): def __init__(self,feature_columns_1, feature_columns_2, trainable=True, name=None, **kwargs): super(myInputLayer,self).__init__(trainable=trainable,name=name,**kwargs) self._feature_columns_1 = feature_columns_1 self._feature_columns_2 = feature_columns_2 self._state_magager = fc_v2._StateManagerImpl(self,trainable) def build(self, input_shape): with tf.variable_scope(self.name): for column in self._feature_columns_1: with tf.variable_scope(column.name): column.create_state(self._state_magager) for column in self._feature_columns_2: with tf.variable_scope(column.name): column.create_state(self._state_magager) super(myInputLayer,self).build(None) def call(self, inputs, **kwargs): transformation_catch = fc_v2.FeatureTransformationCache(inputs) output_tensors = [] for column in self._feature_columns_1 + self._feature_columns_2: with tf.name_scope(column.name): tensor = column.get_dense_tensor(transformation_catch,self._state_magager) num_elements = column.variable_shape.num_elements() batch_size = tf.shape(tensor)[0] output_tensor = tf.reshape(tensor,shape=(batch_size,num_elements)) output_tensors.append(output_tensor) return output_tensors
4、在estimator中如何使用
def model_fn(features,labels,mode): inputnet = myInputLayer(feature_columns_new[0],features,feature_columns_new[1],name="inputlayer") rs = inputnet(features) # rs之后用于dnn层
转载地址:https://blog.csdn.net/h_jlwg6688/article/details/116700367 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!
发表评论
最新留言
第一次来,支持一个
[***.219.124.196]2024年03月15日 06时40分06秒
关于作者
喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!
推荐文章
如何修改手机屏幕显示的长宽比例_屏幕分辨率 尺寸 比例 长宽 如何计算
2019-04-21
mysql 的版本 命名规则_MySQL版本和命名规则
2019-04-21
no java stack_Java Stack contains()用法及代码示例
2019-04-21
java动态代码_Java Agent入门学习之动态修改代码
2019-04-21
python集合如何去除重复数据_Python 迭代删除重复项,集合删除重复项
2019-04-21
java 验证码校验_JavaWeb验证码校验功能代码实例
2019-04-21
java多线程初学者指南_Java多线程初学者指南(4):线程的生命周期
2019-04-21
java添加资源文件_如何在eclipse中将资源文件夹添加到我的Java项目中
2019-04-21
java的三种修饰符_3分钟弄明白JAVA三大修饰符
2019-04-21
PHP字符串运算结果,PHP运算符(二)"字符串运算符"实例详解
2019-04-21
PHP实现 bcrypt,如何使php中的bcrypt和Java中的jbcrypt兼容
2019-04-21
php8安全,PHP八大安全函数解析
2019-04-21
php基础语法了解和熟悉的表现,PHP第二课 了解PHP的基本语法以及目录结构
2019-04-21
matlab中lag函数用法,MATLAB movavg函数用法
2019-04-21
matlab变形监测,基于matlab的变形监测数据处理与分析_毕业设计论文
2019-04-21