Tensorflow回归模型不';不要从负面价值中学习

Tensorflow回归模型不';不要从负面价值中学习,tensorflow,Tensorflow,我使用tensorflow处理非平凡回归问题。 我的网络输入(包含192个二进制(0或1)元素的数组)表示扑克手牌、扑克下注和动作,输出(包含单个元素的数组)表示此动作的收益。 网络有3个隐藏层,包含100个神经元 input_layer = tf.placeholder(tf.float32, (None, self.INPUT_LAYER_SIZE)) layer_1_biases = tf.Variable(tf.truncated_normal(stddev=self.S

我使用tensorflow处理非平凡回归问题。 我的网络输入(包含192个二进制(0或1)元素的数组)表示扑克手牌、扑克下注和动作,输出(包含单个元素的数组)表示此动作的收益。 网络有3个隐藏层,包含100个神经元

    input_layer = tf.placeholder(tf.float32, (None, self.INPUT_LAYER_SIZE))
    layer_1_biases = tf.Variable(tf.truncated_normal(stddev=self.STD_DEV, shape=[self.LAYER_1_SIZE]))
    layer_1_weights = tf.Variable(tf.truncated_normal(stddev=self.STD_DEV, shape=[self.INPUT_LAYER_SIZE, self.LAYER_1_SIZE]))
    layer_2_biases = tf.Variable(tf.truncated_normal(stddev=self.STD_DEV, shape=[self.LAYER_2_SIZE]))
    layer_2_weights = tf.Variable(tf.truncated_normal(stddev=self.STD_DEV, shape=[self.LAYER_1_SIZE, self.LAYER_2_SIZE]))
    layer_3_biases = tf.Variable(tf.truncated_normal(stddev=self.STD_DEV, shape=[self.LAYER_3_SIZE]))
    layer_3_weights = tf.Variable(
        tf.truncated_normal(stddev=self.STD_DEV, shape=[self.LAYER_2_SIZE, self.LAYER_3_SIZE]))

    layer_1 = tf.add(tf.matmul(input_layer, layer_1_weights), layer_1_biases)
    layer_1 = tf.nn.relu(layer_1)
    layer_2 = tf.add(tf.matmul(layer_1, layer_2_weights), layer_2_biases)
    layer_2 = tf.nn.relu(layer_2)
    layer_3 = tf.add(tf.matmul(layer_2, layer_3_weights), layer_3_biases)
    layer_3 = tf.nn.relu(layer_3)

    output_layer_weights = tf.ones(shape=[self.LAYER_2_SIZE,1])
    output_layer = tf.matmul(layer3, output_layer_weights)

    return input_layer, output_layer
我尝试使用学习率0.001和GradientDescentOptimizer最小化平方误差:

self._target = tf.placeholder("float", [None])
self.cost = tf.reduce_mean(tf.square(self._target - tf.transpose(self.output_layer)))
self._train_operation = tf.train.GradientDescentOptimizer(self.LEARNING_RATE).minimize(self.cost)
经过培训(100次迭代),我达到了这个目标:

[-10.0, -10.0, -10.0, +10.0, -10.0]
此输出:

[0.08981139, 0.05091755, 0.04566674, 0.06034175, 9.99115811, 0.13543463]
在培训期间,我也没有在日志文件中看到任何负面输出。所以我的人际网络不会从负面价值中学习。有什么想法吗?

简单回答: 激活函数ReLU是:
max(0,x)
其中
x
是输入的加权和。所以函数总是正的。您应该使用其他激活功能(如
tanh