If statement keras(tensorflow后端)带K.switch()的条件赋值

If statement keras(tensorflow后端)带K.switch()的条件赋值,if-statement,tensorflow,keras,switch-statement,backend,If Statement,Tensorflow,Keras,Switch Statement,Backend,我正在尝试实现类似的功能 if np.max(subgrid) == np.min(subgrid): middle_middle = cur_subgrid + 1 else: middle_middle = cur_subgrid 因为条件只能在运行时确定,所以我使用Keras语法如下 middle_middle = K.switch(K.max(subgrid) == K.min(subgrid), lambda: tf.add(cur_subgrid,1), lambd

我正在尝试实现类似的功能

if np.max(subgrid) == np.min(subgrid):
    middle_middle = cur_subgrid + 1
else:
    middle_middle = cur_subgrid
因为条件只能在运行时确定,所以我使用Keras语法如下

middle_middle = K.switch(K.max(subgrid) == K.min(subgrid), lambda: tf.add(cur_subgrid,1), lambda: cur_subgrid)
但我得到了一个错误:

<ipython-input-112-0504ce070e71> in col_loop(j, gray_map, mask_A)
     56 
     57 
---> 58             middle_middle = K.switch(K.max(subgrid) == K.min(subgrid), lambda: tf.add(cur_subgrid,1), lambda: cur_subgrid)
     59 
     60             print ('ml',middle_left.shape)

/nfs/isicvlnas01/share/anaconda3/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py in switch(condition, then_expression, else_expression)    2561         The selected tensor.    2562     """
-> 2563     if condition.dtype != tf.bool:    2564         condition = tf.cast(condition, 'bool')    2565     if not callable(then_expression):

AttributeError: 'bool' object has no attribute 'dtype'
在col\u循环中(j、灰度图、遮罩图)
56
57
--->58 middle_middle=K.switch(K.max(子网格)=K.min(子网格),lambda:tf.add(cur_子网格,1),lambda:cur_子网格)
59
60号打印('ml',中左形状)
/nfs/isicvlnas01/share/anaconda3/lib/python3.5/site-packages/keras/backend/tensorflow\u backend.py在switch(condition、then\u expression、else\u expression)中2561选择的张量。2562     """
->2563 if condition.dtype!=tf.bool:2564 condition=tf.cast(condition,'bool')2565 if不可调用(then_表达式):
AttributeError:“bool”对象没有属性“dtype”

middle\u middle
cur\u subgrid
和subgrid都是
NxN
张量。非常感谢您的帮助。

我认为问题在于
K.max(subgrid)=K.min(subgrid)
您正在创建一个比较两个张量对象的python布尔,而不是一个包含两个输入张量值比较值的tensorflow布尔张量

换句话说,您所写的内容将被评估为

K.switch(False, lambda: tf.add(cur_subgrid,1), lambda: cur_subgrid)
而不是

comparison = ... # Some tensor, that at runtime will contain True if min and max are the same, False otherwise. 
K.switch(comparison , lambda: tf.add(cur_subgrid,1), lambda: cur_subgrid)
因此,您需要做的是使用而不是
=

K.switch(K.equal(K.max(subgrid),K.min(subgrid)), lambda: tf.add(cur_subgrid,1), lambda: cur_subgrid)

很抱歉,我现在不能测试这个,所以我希望它能工作。非常感谢!这正是原因。