Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/tensorflow/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Tensorflow 使用tf.contrib.learn模块时如何创建tf.RunMetadata并将其添加到writer_Tensorflow - Fatal编程技术网

Tensorflow 使用tf.contrib.learn模块时如何创建tf.RunMetadata并将其添加到writer

Tensorflow 使用tf.contrib.learn模块时如何创建tf.RunMetadata并将其添加到writer,tensorflow,Tensorflow,现在我使用tf.contrib.learn.Experiment、Estimator、learn_runner来帮助训练模型。当运行learn_runner时，它将隐式地创建一个tf.train.MonitoredSession并调用其run()函数，因此我无法将参数options和run_metadata添加到run()函数中。那么，如何将options和run_metadata参数添加到run()函数并调用summary_writer.add_run_metadata()？我在网上搜索了很长时间，但是没有用。请帮助或尝试给出一些如何实现这一点的想法。代码如下：

现在我使用
tf.contrib.learn.Experiment、Estimator、learn_runner
来帮助训练模型。当运行
learn_runner
时，它将隐式地创建一个
tf.train.MonitoredSession
并调用其
run()
函数，因此我无法将参数
options
和
run_metadata
添加到
run()
函数中

那么，如何将
options
和
run_metadata
参数添加到
run()
函数并调用
summary_writer.add_run_metadata()

我在网上搜索了很长时间。但是没有用。请帮助或尝试给出一些如何实现这一点的想法

代码如下:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np

from tensorflow.contrib import slim, training, learn

# Emit verbose (DEBUG-level) TensorFlow log messages for the whole run.
tf.logging.set_verbosity(tf.logging.DEBUG)


def variable_summaries(var):
    """Attach scalar and histogram summaries for ``var``.

    All summary ops are registered in the ``'variable_summaries'`` graph
    collection (instead of the default summaries collection) so that
    ``model_fn`` can merge exactly this set. The name scope is derived
    from the tensor's name with the ``:0`` port suffix stripped, which
    groups the summaries per variable in TensorBoard.
    """
    scope_name = var.name.split(':')[0]
    with tf.name_scope(scope_name):
        mean = tf.reduce_mean(var)
        with tf.name_scope('stddev'):
            stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
        summaries = [
            tf.summary.scalar('mean', mean),
            tf.summary.scalar('stddev', stddev),
            tf.summary.scalar('max', tf.reduce_max(var)),
            tf.summary.scalar('min', tf.reduce_min(var)),
            tf.summary.histogram('histogram', var),
        ]
        for summary_op in summaries:
            tf.add_to_collection('variable_summaries', summary_op)


def model_fn(features, labels, mode, params):
    """Model function for ``tf.contrib.learn.Estimator``.

    Builds a small two-layer fully connected network and returns
    ``ModelFnOps`` appropriate for the given mode: loss/train_op/metrics
    for TRAIN and EVAL, prediction tensors for INFER.

    Args:
        features: Batched feature tensor produced by the input_fn.
        labels: Batched integer label tensor (TRAIN/EVAL only).
        mode: One of ``learn.ModeKeys.{TRAIN, EVAL, INFER}``.
        params: HParams object (unused here beyond the Estimator contract).
    """
    # Tensor stashed in a collection by the input_fn; echoed back in the
    # INFER predictions so callers can match outputs to input rows.
    id_ts = tf.get_collection('id_ts')[0]

    fc1 = slim.fully_connected(features, 10, tf.nn.relu, scope='fc1')
    variable_summaries(fc1)
    fc2 = slim.fully_connected(fc1, 2, None, scope='fc2')
    variable_summaries(fc2)

    for trainable_var in tf.trainable_variables():
        variable_summaries(trainable_var)

    logits = fc2
    prob = tf.nn.softmax(logits)
    predictions = tf.argmax(logits, axis=1)

    # Merge only the summaries registered under 'variable_summaries' and
    # hand them to the Scaffold, so the MonitoredSession writes them.
    summary_op = tf.summary.merge_all('variable_summaries')
    scaffold = tf.train.Scaffold(summary_op=summary_op)

    if mode in (learn.ModeKeys.TRAIN, learn.ModeKeys.EVAL):
        onehot_labels = slim.one_hot_encoding(labels, 2)
        loss = tf.losses.softmax_cross_entropy(logits=logits,
                                               onehot_labels=onehot_labels)

        optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
        # Passing the global step makes minimize() increment it each step,
        # which the Estimator relies on for checkpointing/step counting.
        train_op = optimizer.minimize(loss, slim.get_global_step())

        eval_metric_ops = {
            'accuracy': tf.metrics.accuracy(labels, predictions),
            'auc': tf.metrics.auc(labels, predictions),
            'precision': tf.metrics.precision(labels, predictions),
            'recall': tf.metrics.recall(labels, predictions),
        }

        return learn.ModelFnOps(mode=mode,
                                predictions=predictions,
                                loss=loss,
                                train_op=train_op,
                                eval_metric_ops=eval_metric_ops,
                                scaffold=scaffold)
    elif mode == learn.ModeKeys.INFER:
        return learn.ModelFnOps(mode=mode,
                                predictions={'prob': prob,
                                             'fc1': fc1,
                                             'fc2': fc2,
                                             'id': id_ts})


def train_input_fn():
    """Training input pipeline: shuffled batches read from ``data.csv``.

    Returns a ``(features, labels)`` pair of batched tensors. The first
    CSV column (row id) is also stored in the ``'id_ts'`` collection so
    ``model_fn`` can surface it at inference time.
    """
    filename_queue = tf.train.string_input_producer(['data.csv'])
    _, row = tf.TextLineReader().read(filename_queue)
    columns = tf.decode_csv(row, [[0.], [0.], [0.], [0.]], field_delim=',')
    # batch_size=10, capacity=1000, min_after_dequeue=10
    id_col, x1_col, x2_col, label_col = tf.train.shuffle_batch(
        columns, 10, 1000, 10)
    tf.add_to_collection('id_ts', id_col)
    features_ts = tf.concat(
        [tf.reshape(x1_col, [-1, 1]), tf.reshape(x2_col, [-1, 1])], axis=1)
    labels_ts = tf.cast(label_col, tf.int32)
    return features_ts, labels_ts


def eval_input_fn():
    """Evaluation input pipeline: deterministic batches from ``data.csv``.

    Same layout as ``train_input_fn`` but uses ``tf.train.batch`` (no
    shuffling) so evaluation order is stable. The row-id column is stored
    in the ``'id_ts'`` collection for ``model_fn``.
    """
    filename_queue = tf.train.string_input_producer(['data.csv'])
    _, row = tf.TextLineReader().read(filename_queue)
    columns = tf.decode_csv(row, [[0.], [0.], [0.], [0.]], field_delim=',')
    # batch_size=10, capacity=1000
    id_col, x1_col, x2_col, label_col = tf.train.batch(columns, 10, 1000)
    tf.add_to_collection('id_ts', id_col)
    features_ts = tf.concat(
        [tf.reshape(x1_col, [-1, 1]), tf.reshape(x2_col, [-1, 1])], axis=1)
    labels_ts = tf.cast(label_col, tf.int32)
    return features_ts, labels_ts


def run_experiment(_):
    """Entry point for ``tf.app.run``: configure and launch the experiment.

    Builds the session/run configuration and hands control to
    ``learn_runner``, which drives the train-and-evaluate loop through an
    implicitly created MonitoredSession.
    """
    gpu_options = tf.GPUOptions(allow_growth=True)
    session_config = tf.ConfigProto(gpu_options=gpu_options,
                                    log_device_placement=False)

    run_config = learn.RunConfig(
        save_checkpoints_steps=100,
        model_dir='model_dir',
        session_config=session_config,
        keep_checkpoint_max=2)

    hparams = training.HParams(train_steps=1000)

    learn.learn_runner.run(
        experiment_fn=create_experiment_fn,
        schedule='train_and_evaluate',
        run_config=run_config,
        hparams=hparams)


def create_experiment_fn(run_config, hparams):
    """Build the Experiment consumed by ``learn_runner.run``.

    Wires the estimator together with the train/eval input pipelines and
    the configured number of training steps.
    """
    return learn.Experiment(
        estimator=get_estimator_fn(config=run_config, params=hparams),
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        train_steps=hparams.train_steps)


def get_estimator_fn(config, params):
    """Construct the ``learn.Estimator`` wrapping ``model_fn``."""
    estimator = learn.Estimator(model_fn=model_fn,
                                model_dir=config.model_dir,
                                config=config,
                                params=params)
    return estimator


if __name__ == '__main__':
    # tf.app.run parses flags and then invokes run_experiment(argv).
    tf.app.run(main=run_experiment)

你弄明白了吗?