tensorflow中的加权成本函数

tensorflow中的加权成本函数,tensorflow,Tensorflow,我试图在以下成本函数中引入权重: _cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=_logits, labels=y)) 但不必自己做softmax交叉熵。所以我想把成本计算分解成成本1和成本2,并将我的logits和y值的修改版本输入到每一个 我想这样做,但不确定正确的代码是什么: mask=(y==0) y0 = tf.boolean_mask(y,mask)*y1Weight (这

我试图在以下成本函数中引入权重:

_cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=_logits, labels=y))
但不必自己做softmax交叉熵。所以我想把成本计算分解成成本1和成本2,并将我的logits和y值的修改版本输入到每一个

我想这样做,但不确定正确的代码是什么:

mask=(y==0)
y0 = tf.boolean_mask(y,mask)*y1Weight

(这给出了掩码不能是标量的错误)

您可以按如下方式计算加权成本;使用预定义的
每类权重
张量和形状
(num\u类,1)
。对于标签,请使用编码

# here labels shape should be [batch_size, num_classes] ; obtained using one_hot
_cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=_logits, labels=y)

# Here you can define a deterministic weights tensor. 
# weights_per_class = tf.constant(np.array([y0weights, y1weights, ...]))
weights_per_class =tf.random_normal(shape=(num_classes, 1), dtype=tf.float32)

# Use the weights tensor to compute weighted loss
_weighted_cost =  tf.reduce_mean(tf.matmul(_cost, weights_per_class))

可以使用
tf.where
计算权重掩码。以下是加权成本示例:

batch_size = 100
y1Weight = 0.25
y0Weight = 0.75


_logits = tf.Variable(tf.random_normal(shape=(batch_size, 2), stddev=1.))
y = tf.random_uniform(shape=(batch_size,), maxval=2, dtype=tf.int32)

_cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=_logits, labels=y)

#Weight mask, the weights for label=0 is y0Weight and for 1 is y1Weight
y_w = tf.where(tf.cast(y, tf.bool), tf.ones((batch_size,))*y0Weight, tf.ones((batch_size,))*y1Weight)

# New weighted cost
cost_w = tf.reduce_mean(tf.multiply(_cost, y_w))

正如@user1761806所建议的,更简单的解决方案是使用
tf.loss.sparse\u softmax\u cross\u entropy()
,它允许对类进行加权。

谢谢-我还发现,sparse\u softmax\u cross\u entropy是一个包装函数,它利用sparse\u softmax\u cross\u entropy\u和\u logits,但允许您提供权重。所以我就用了这个<代码>损耗=tf.add(1,tf.multiply(tf.cast(tf.equal(y,1),'int32'),1.5))\u成本=tf.reduce平均值(tf.loss.sparse\u softmax\u交叉熵(logits=\u logits,labels=y,weights=lossW))您不应该在
tf.where
中切换
y0Weight
?在
tf中,其中(条件,x,y)
x
用于情况
条件
,这意味着这里的
标签=1