Python Tensorflow-条件训练

Python Tensorflow-条件训练,python,tensorflow,if-statement,Python,Tensorflow,If Statement,我正在用tensorflow(1.12)以有监督的方式训练神经网络。我只想训练一些具体的例子。示例是通过剪切子序列动态创建的,因此我想在tensorflow中进行条件处理 这是我最初的代码部分: train_step, gvs = minimize_clipped(optimizer, loss, clip_value=FLAGS.gradient_clip, return

我正在用tensorflow(1.12)以有监督的方式训练神经网络。我只想训练一些具体的例子。示例是通过剪切子序列动态创建的,因此我想在tensorflow中进行条件处理

这是我最初的代码部分:

train_step, gvs = minimize_clipped(optimizer, loss,
                               clip_value=FLAGS.gradient_clip,
                               return_gvs=True)
gradients = [g for (g,v) in gvs]
gradient_norm = tf.global_norm(gradients)
tf.summary.scalar('gradients/norm', gradient_norm)
eval_losses = {'loss1': loss1,
               'loss2': loss2}
培训步骤稍后执行为:

batch_eval, _ = sess.run([eval_losses, train_step])
我在考虑插入类似的内容

train_step_fake = ????
eval_losses_fake = tf.zeros_like(tensor)
train_step_new = tf.cond(my_cond, train_step, train_step_fake)
eval_losses_new = tf.cond(my_cond, eval_losses, eval_losses_fake)
然后做什么

batch_eval, _ = sess.run([eval_losses, train_step])
然而,我不知道如何创造一个假的火车步

另外,总体而言,这是一个好主意还是有一种更顺畅的方法?我使用的是tfrecords管道,但没有其他高级模块(如keras、tf.estimator、eager execution等)


我们非常感谢您的帮助

先回答具体问题。当然,只可能根据
tf.cond
结果执行训练步骤。请注意,第2个和第3个参数是lambdas,但更像:

train_step_new = tf.cond(my_cond, lambda: train_step, lambda: train_step_fake)
eval_losses_new = tf.cond(my_cond, lambda: eval_losses, lambda: eval_losses_fake)
不过,你的直觉是,这可能不是正确的做法

更可取的做法是(无论是从效率还是从阅读和推理代码的角度)在数据到达模型之前过滤掉您想要忽略的数据

这是您可以通过使用来实现的。它有一个非常有用的
filter()
方法供您使用。如果您现在正在使用dataset api读取TFRecords,那么这应该非常简单,只需添加以下内容:

dataset = dataset.filter(lambda x: {whatever op you were going to use in tf.cond})

如果你还没有使用DataSet API,现在可能是时候对它进行一点点的阅读,而不是把模型用那个<代码> TF.COND()/<代码>作为过滤器来处理。