Python 如何在Keras密集层中共享权重而不产生偏差
我正试图创建一个有序回归模型,正如下面所解释的。它的一个主要部分是在最后一层中共享权重,但不是为了获得秩单调性(基本上是为了确保对于任何这样的N,P[Y>N]必须始终大于P[Y>N-1])。这对我来说是非常理想的,因为我有几个值,其中只有很少的值,但我仍然希望得到它们的概率。到目前为止,我已经实现了它编码数字的方式,并且没有像有时P(Y>5)>P(Y>4)的概率那样的秩单调性 我怎样才能在Keras中实现权重共享而不是偏差共享?我知道函数式API有一种共享权重和偏差的方法,但在这种情况下没有帮助。感谢所有能帮忙的人 编辑:在一个层内与N个神经元和N个层之间共享权重但不共享偏差都可以解决问题。另外,我认为将Dense()中的use_bias参数设置为false并创建某种类型的自定义bias层也可以解决问题,但我不确定如何做到这一点 我认为六个神经元和五个输入的方程式是这样的Python 如何在Keras密集层中共享权重而不产生偏差,python,tensorflow,keras,deep-learning,neural-network,Python,Tensorflow,Keras,Deep Learning,Neural Network,我正试图创建一个有序回归模型,正如下面所解释的。它的一个主要部分是在最后一层中共享权重,但不是为了获得秩单调性(基本上是为了确保对于任何这样的N,P[Y>N]必须始终大于P[Y>N-1])。这对我来说是非常理想的,因为我有几个值,其中只有很少的值,但我仍然希望得到它们的概率。到目前为止,我已经实现了它编码数字的方式,并且没有像有时P(Y>5)>P(Y>4)的概率那样的秩单调性 我怎样才能在Keras中实现权重共享而不是偏差共享?我知道函数式API有一种共享权重和偏差的方法,但在这种情况下没有帮助
op1 = w1z1 + w2z2 + w3z3 + w4z4 + w5z5 + b1
op2 = w1z1 + w2z2 + w3z3 + w4z4 + w5z5 + b2
op3 = w1z1 + w2z2 + w3z3 + w4z4 + w5z5 + b3
op4 = w1z1 + w2z2 + w3z3 + w4z4 + w5z5 + b4
op5 = w1z1 + w2z2 + w3z3 + w4z4 + w5z5 + b5
op6 = w1z1 + w2z2 + w3z3 + w4z4 + w5z5 + b6
其中w1到w5是权重,z1到z5是输入,b1到b6是偏差项,实现这一点的方法之一是定义自定义的
偏差
层,下面是如何做到这一点的。
PS:根据需要更改输入形状/初始值设定项
import tensorflow as tf
print('TensorFlow:', tf.__version__)
class BiasLayer(tf.keras.layers.Layer):
def __init__(self, units, *args, **kwargs):
super(BiasLayer, self).__init__(*args, **kwargs)
self.bias = self.add_weight('bias',
shape=[units],
initializer='zeros',
trainable=True)
def call(self, x):
return x + self.bias
z1 = tf.keras.Input(shape=[1])
z2 = tf.keras.Input(shape=[1])
z3 = tf.keras.Input(shape=[1])
z4 = tf.keras.Input(shape=[1])
z5 = tf.keras.Input(shape=[1])
dense_layer = tf.keras.layers.Dense(units=10, use_bias=False)
op1 = BiasLayer(units=10)(dense_layer(z1))
op2 = BiasLayer(units=10)(dense_layer(z2))
op3 = BiasLayer(units=10)(dense_layer(z3))
op4 = BiasLayer(units=10)(dense_layer(z4))
op5 = BiasLayer(units=10)(dense_layer(z5))
model = tf.keras.Model(inputs=[z1, z2, z3, z4, z5], outputs=[op1, op2, op3, op4, op5])
model.summary()
输出:
TensorFlow: 2.1.0-dev20200107
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
input_2 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
input_3 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
input_4 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
input_5 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
dense (Dense) (None, 10) 10 input_1[0][0]
input_2[0][0]
input_3[0][0]
input_4[0][0]
input_5[0][0]
__________________________________________________________________________________________________
bias_layer (BiasLayer) (None, 10) 10 dense[0][0]
__________________________________________________________________________________________________
bias_layer_1 (BiasLayer) (None, 10) 10 dense[1][0]
__________________________________________________________________________________________________
bias_layer_2 (BiasLayer) (None, 10) 10 dense[2][0]
__________________________________________________________________________________________________
bias_layer_3 (BiasLayer) (None, 10) 10 dense[3][0]
__________________________________________________________________________________________________
bias_layer_4 (BiasLayer) (None, 10) 10 dense[4][0]
==================================================================================================
Total params: 60
Trainable params: 60
Non-trainable params: 0
__________________________________________________________________________________________________
你介意看看类似的(我猜)堆栈吗?