TensorFlow系列——在estimator中使用feature_column处理tfrecord特征
发布日期:2021-09-30 09:33:46 浏览次数:1 分类:技术文章

本文共 3929 字,大约阅读时间需要 13 分钟。

一、用于现成的estimator模型

1、读取tfrecord数据

1.1、tfrecord中包含所有(特征名-值)以及标签的情况

feature_schema = {    # 包含了tfrecord里的所有特征,包括标签label    "sex": tf.io.FixedLenFeature(shape=(1,), dtype=tf.int64),    "age": tf.io.FixedLenFeature(shape=(1,), dtype=tf.int64),    "label": tf.io.FixedLenFeature(shape=(1,), dtype=tf.float32)}# train_files就是tfrecord文件列表tf.data.experimental.make_batched_features_dataset(            file_pattern=train_files,            features=feature_schema,            label_key="label")

1.2、tfrecord中将各个特征数据放在一个数组中的情况("user_data"—array[值])

参考:

2、定义feature_column

def get_feature_columns(args):    user_sex = tf.feature_column.categorical_column_with_identity(key="sex", num_buckets=3, default_value=0)    user_age = tf.feature_column.categorical_column_with_identity(key="age", num_buckets=9, default_value=0)    return [user_sex,user_age]

3、使用已有的estimator模型

dnn_model = estimator_lib.DNNClassifier(hidden_units=[...],                                            model_dir=FLAGS.model_dir,                                            feature_columns=get_feature_columns([...]),...)

二、使用input_layer方式用于自定义estimator模型

1、读取tfrecord数据

方式同上

2、定义feature_column

方式同上

3、在自定义estimator中使用feature_column.input_layer方法

model_fn为自定义estimator的主体部分

# feature_columns_new 为2中定义的feature_column# features 为从tfrecord中获取的tensor# inputs_layers 为神经网络层的输入def model_fn(features,labels,mode):    inputs_layers =tf.feature_column.input_layer(features,feature_columns_new)    # 可以多次使用input_layer    ...

4、使用自定义estimator

# 不需要feature及feature_column及feature_schema作为参数输入dnn_model =  MyEstimator(        model_dir=FLAGS.model_dir,        optimizer = ...,        hidden_units = [...])# 通过下面方法直接从train_spec输出到自定义的estimator的model_fn方法中tf.estimator.train_and_evaluate(estimator=dnn_model,train_spec=train_spec,eval_spec=eval_spec)

5、问题说明

以上使用的版本为TensorFlow1.x,对于TensorFlow2.x修改如下:

from tensorflow.python.feature_column import feature_column as fc_v1rs = fc_v1.input_layer(features=feaDics,feature_columns=get_feature_columns_new())# 或者:rs = tf.compat.v1.feature_column.input_layer(features=feaDics,feature_columns=get_feature_columns_new())

三、自定义输入层用于自定义estimator模型

1、读取tfrecord数据

方式同上

2、定义feature_column

方式同上

3、重写自定义的inputlayer输入层

参考:

# 重写自定义层的标准方式:class myInputLayer(tf.keras.layers.Layer):    def __init__(self,feature_columns_1,                 feature_columns_2,                 trainable=True,                 name=None,                 **kwargs):        super(myInputLayer,self).__init__(trainable=trainable,name=name,**kwargs)        self._feature_columns_1 = feature_columns_1        self._feature_columns_2 = feature_columns_2        self._state_magager = fc_v2._StateManagerImpl(self,trainable)    def build(self, input_shape):        with tf.variable_scope(self.name):            for column in  self._feature_columns_1:                with tf.variable_scope(column.name):                    column.create_state(self._state_magager)            for column in self._feature_columns_2:                with tf.variable_scope(column.name):                    column.create_state(self._state_magager)        super(myInputLayer,self).build(None)    def call(self, inputs, **kwargs):        transformation_catch = fc_v2.FeatureTransformationCache(inputs)        output_tensors = []        for column in self._feature_columns_1 + self._feature_columns_2:            with tf.name_scope(column.name):                tensor = column.get_dense_tensor(transformation_catch,self._state_magager)                num_elements = column.variable_shape.num_elements()                batch_size = tf.shape(tensor)[0]                output_tensor = tf.reshape(tensor,shape=(batch_size,num_elements))                output_tensors.append(output_tensor)        return output_tensors

4、在estimator中如何使用

def model_fn(features,labels,mode):    inputnet = myInputLayer(feature_columns_new[0],features,feature_columns_new[1],name="inputlayer")    rs = inputnet(features)    # rs之后用于dnn层

 

转载地址:https://blog.csdn.net/h_jlwg6688/article/details/116700367 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!

上一篇:TensorFlow系列——环境相关
下一篇:tensorflow系列——读取tfrecord数据

发表评论

最新留言

第一次来,支持一个
[***.219.124.196]2024年03月15日 06时40分06秒

关于作者

    喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!

推荐文章

如何修改手机屏幕显示的长宽比例_屏幕分辨率 尺寸 比例 长宽 如何计算 2019-04-21
mysql 的版本 命名规则_MySQL版本和命名规则 2019-04-21
no java stack_Java Stack contains()用法及代码示例 2019-04-21
java动态代码_Java Agent入门学习之动态修改代码 2019-04-21
python集合如何去除重复数据_Python 迭代删除重复项,集合删除重复项 2019-04-21
iview 自定义时间选择器组件_Vue.js中使用iView日期选择器并设置开始时间结束时间校验功能... 2019-04-21
java 验证码校验_JavaWeb验证码校验功能代码实例 2019-04-21
java多线程初学者指南_Java多线程初学者指南(4):线程的生命周期 2019-04-21
java进程user是jenkins_java 学习:在java中启动其他应用,由jenkins想到的 2019-04-21
java添加资源文件_如何在eclipse中将资源文件夹添加到我的Java项目中 2019-04-21
java的三种修饰符_3分钟弄明白JAVA三大修饰符 2019-04-21
mysql source skip_redis mysql 中的跳表(skip list) 查找树(btree) 2019-04-21
java sun.org.mozilla_maven编译找不到符号 sun.org.mozilla.javascript.internal 2019-04-21
php curl 输出到文件,PHP 利用CURL(HTTP)实现服务器上传文件至另一服务器 2019-04-21
PHP字符串运算结果,PHP运算符(二)"字符串运算符"实例详解 2019-04-21
PHP实现 bcrypt,如何使php中的bcrypt和Java中的jbcrypt兼容 2019-04-21
php8安全,PHP八大安全函数解析 2019-04-21
php基础语法了解和熟悉的表现,PHP第二课 了解PHP的基本语法以及目录结构 2019-04-21
matlab中lag函数用法,MATLAB movavg函数用法 2019-04-21
matlab变形监测,基于matlab的变形监测数据处理与分析_毕业设计论文 2019-04-21