Python: Training different branches of a model network with tf.switch_case


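I want to create a neural network in which different branches of the network are trained depending on t_input. t_input can be either 0 or 1, and depending on its value only the matching branch should be trained: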

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda, Dense

x = np.random.uniform(size=(10, 10))
t = np.random.binomial(100, 0.5)

t_input = Input(batch_shape=(1,), dtype='int32', name="t_input")
x_input = Input(shape=(x.shape[0]), name='x_input')

x = Dense(32)(x_input)

x1 = Dense(16)(x)
x1 = Dense(8)(x1)
x1 = Dense(1)(x1)

x2 = Dense(16)(x)
x2 = Dense(8)(x2)
x2 = Dense(1)(x2)

x1 = lambda: x1
x2 = lambda: x2

r = Lambda(lambda x: tf.switch_case(x, branch_fns={0: x1, 1: x2}))(t_input)

# r = tf.case([(tf.equal(t_input, 1), x1), (tf.equal(t_input, 0), x2)], default=x2, exclusive=True)

model = tf.keras.models.Model(inputs=t_input, outputs=r)

print(model.predict([1]))
However, I can't get this to work, because KerasTensors are not flexible enough for it:

Traceback (most recent call last):
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\IPython\core\interactiveshell.py", line 3437, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-59-92db0d55c181>", line 23, in <module>
    r = Lambda(lambda x: tf.switch_case(x, branch_fns={0: x1, 1: x2}))(t_input)
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 952, in __call__
    input_list)
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1091, in _functional_construction_call
    inputs, input_masks, args, kwargs)
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 822, in _keras_tensor_symbolic_call
    return self._infer_output_signature(inputs, args, kwargs, input_masks)
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 863, in _infer_output_signature
    outputs = call_fn(inputs, *args, **kwargs)
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\keras\layers\core.py", line 917, in call
    result = self.function(inputs, **kwargs)
  File "<ipython-input-59-92db0d55c181>", line 23, in <lambda>
    r = Lambda(lambda x: tf.switch_case(x, branch_fns={0: x1, 1: x2}))(t_input)
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 3616, in switch_case
    return _indexed_case_helper(branch_fns, default, branch_index, name)
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 3326, in _indexed_case_helper
    lower_using_switch_merge=lower_using_switch_merge)
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\ops\cond_v2.py", line 1040, in indexed_case
    op_return_value=branch_index))
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\framework\func_graph.py", line 995, in func_graph_from_py_func
    expand_composites=True)
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\util\nest.py", line 659, in map_structure
    structure[0], [func(*x) for x in entries],
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\util\nest.py", line 659, in <listcomp>
    structure[0], [func(*x) for x in entries],
  File "C:\Users\gen06917\PycharmProjects\BaysianTarnet\.venv\lib\site-packages\tensorflow\python\framework\func_graph.py", line 952, in convert
    (str(python_func), type(x)))
TypeError: To be compatible with tf.eager.defun, Python functions must return zero or more Tensors; in compilation of <function <lambda> at 0x000001ED0876EAF8>, found return value of type <class 'function'>, which is not a Tensor.
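The `function` mentioned in this TypeError most likely comes from the lines `x1 = lambda: x1` and `x2 = lambda: x2` rather than from Keras itself: Python looks up the name inside a lambda when the lambda is called, and by then `x1` refers to the lambda, so `tf.switch_case` receives a branch that returns a function instead of a tensor. The plain-Python sketch below reproduces this with strings standing in for the Dense outputs (the names `branch_output`, `out1` and `branch_fn_1` are hypothetical). Renaming alone may not be enough, since `x_input` is also missing from the model's inputs.

```python
# Plain-Python sketch (no TensorFlow): a string stands in for a Dense(1) output.
branch_output = "output of branch 1"

x1 = branch_output
x1 = lambda: x1   # the name x1 is resolved when the lambda is *called*...
result = x1()     # ...and by then x1 is the lambda itself
print(type(result).__name__)  # a function, as reported in the TypeError

# Keeping the tensor and the branch function under different names avoids the self-reference
out1 = branch_output
branch_fn_1 = lambda: out1
print(branch_fn_1())
```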

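I got your code working by replacing your tf.switch_case with a Keras switch and feeding two separate inputs into the model (your code only feeds one of them). Note that I had to tile your t_test input, because the model expects both inputs to have the same batch dimension. I'm also not sure you want np.random.binomial: np.random.binomial(100, 0.5) samples from a binomial distribution with 100 trials, so it returns values around 50 and will almost never return 0. You should probably look at np.random.randint and restrict its values to 0 or 1.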
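A small NumPy sketch of the two points above, assuming the same shapes as the question (10 samples, one flag per sample); the seed and the names `draw`, `p_zero`, `flags` and `t_tiled` are only for illustration.

```python
import numpy as np

np.random.seed(0)  # only to make the sketch reproducible

# One draw from Binomial(n=100, p=0.5): almost always somewhere around 50
draw = np.random.binomial(100, 0.5)
# Probability that such a draw is exactly 0: 0.5**100, roughly 7.9e-31
p_zero = 0.5 ** 100

# 0/1 branch flags via randint (the upper bound is exclusive)
flags = np.random.randint(0, 2, size=10)

# A single flag tiled to one row per sample, matching x_test's batch size
x_test = np.random.uniform(size=(10, 10))
t_test = np.array([1])
t_tiled = np.tile(t_test, (x_test.shape[0], 1))

print(draw, p_zero, flags, t_tiled.shape)
```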

import tensorflow as tf
from tensorflow.keras.layers import Input, Dense
import tensorflow.keras as K
import numpy as np

x_test = np.random.uniform(size=(10, 10))
# Branch flag: 0 or 1 (np.random.binomial(100, 0.5) would give values around 50)
t_test = np.random.randint(0, 2, size=(1,))

t_input = Input(shape=(1,), dtype=tf.int32, name="t_input")
x_input = Input(shape=(x_test.shape[1],), name='x_input')

# Shared trunk
x = Dense(32)(x_input)

# Branch 1
x1 = Dense(16)(x)
x1 = Dense(8)(x1)
x1 = Dense(1)(x1)

# Branch 2
x2 = Dense(16)(x)
x2 = Dense(8)(x2)
x2 = Dense(1)(x2)

# Per sample: t_input != 0 selects x1, t_input == 0 selects x2
r = K.backend.switch(t_input, x1, x2)

# Both the flag and the features are model inputs
model = tf.keras.models.Model(inputs=[t_input, x_input], outputs=r)

# Tile the single flag so both inputs have the same batch size (10)
print(model.predict([np.tile(t_test, (10, 1)), x_test]))
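As far as I understand `tf.keras.backend.switch`, when the condition has the same rank as the branch tensors it is cast to bool and applied elementwise, like `tf.where`, so each sample gets `x1` where `t_input` is nonzero and `x2` where it is zero. The NumPy sketch below mimics that with `np.where` and also shows why this matches "only train the matching branch": the gradient of a where-selection sends each sample's loss gradient only to the branch it picked. The arrays and the squared-error loss are made up for illustration.

```python
import numpy as np

# Hypothetical per-sample branch outputs, shape (batch, 1) like Dense(1)
x1 = np.array([[0.5], [1.5], [2.5], [3.5]])
x2 = np.array([[-1.0], [-2.0], [-3.0], [-4.0]])
t = np.array([[1], [0], [0], [1]], dtype=np.int32)  # shape (batch, 1), like t_input

# Same-rank condition: cast to bool, then select elementwise
sel = t.astype(bool)
r = np.where(sel, x1, x2)

# Gradient of loss = sum((r - y)**2) with respect to each branch output:
# a sample contributes gradient only to the branch it selected.
y = np.zeros_like(r)
d_r = 2.0 * (r - y)
d_x1 = np.where(sel, d_r, 0.0)
d_x2 = np.where(sel, 0.0, d_r)

print(r.ravel(), d_x1.ravel(), d_x2.ravel())
```

The shared `Dense(32)` trunk is still updated by every sample, since both branches feed from it.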

I found a way to support more than two branches:

import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda, Dense
import numpy as np

def h_chooser(inputs):
  # inputs = [t_inp of shape (batch, 1), list of (batch, 1) branch outputs]
  t_inp, h_s = inputs
  # Put the branch outputs side by side: shape (batch, n_branches)
  h_out = tf.concat(h_s, axis=1)
  # (row, branch) index pairs: row i reads column t_inp[i]
  t_inds = tf.stack([tf.range(tf.size(t_inp)), tf.squeeze(t_inp, axis=1)], axis=1)
  # Selected outputs, shape (batch,)
  h_res = tf.gather_nd(h_out, t_inds)
  return h_res

x_test = np.random.uniform(size=(10, 10))
# Branch index must be 0, 1 or 2: on CPU, tf.gather_nd raises an error for out-of-range indices
t_test = np.random.randint(0, 3, size=(1,))

t_input = Input(shape=(1,), dtype=tf.int32, name="t_input")
x_input = Input(shape=(x_test.shape[1],), name='x_input')

# Shared trunk
x = Dense(32)(x_input)

# Branch 0
x1 = Dense(16)(x)
x1 = Dense(8)(x1)
x1 = Dense(1)(x1)

# Branch 1
x2 = Dense(16)(x)
x2 = Dense(8)(x2)
x2 = Dense(1)(x2)

# Branch 2
x3 = Dense(16)(x)
x3 = Dense(8)(x3)
x3 = Dense(1)(x3)

h_switch_case = Lambda(h_chooser)

r = h_switch_case([t_input, [x1, x2, x3]])

model = tf.keras.models.Model(inputs=[t_input, x_input], outputs=r)

# Tile the single branch index so both inputs have the same batch size (10)
print(model.predict([np.tile(t_test, (10, 1)), x_test]))
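To see what `h_chooser` computes without building a model, here is a NumPy analogue of its concat plus `gather_nd` step (the function name `choose_branch` and the sample arrays are hypothetical). It also shows an alternative formulation, swapped in here for comparison rather than taken from the answer: multiply the stacked outputs by a one-hot mask of the branch index and sum across branches, which gives the same values.

```python
import numpy as np

def choose_branch(t, branch_outputs):
    """NumPy analogue of h_chooser: row i takes column t[i, 0]."""
    h_out = np.concatenate(branch_outputs, axis=1)  # (batch, n_branches)
    rows = np.arange(h_out.shape[0])
    return h_out[rows, t[:, 0]]                     # (batch,)

# Hypothetical outputs of three Dense(1) branches for a batch of 4
b1 = np.array([[10.0], [11.0], [12.0], [13.0]])
b2 = np.array([[20.0], [21.0], [22.0], [23.0]])
b3 = np.array([[30.0], [31.0], [32.0], [33.0]])
t = np.array([[2], [0], [1], [2]])  # branch index per sample, shape (batch, 1)

picked = choose_branch(t, [b1, b2, b3])

# Alternative: one-hot mask over the branch axis, then sum
mask = np.eye(3)[t[:, 0]]
masked = (np.concatenate([b1, b2, b3], axis=1) * mask).sum(axis=1)

print(picked, masked)
```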