
tensorflow - Strange shape requirement for the `sample_weight` argument of Keras losses in TF2.0

Reposted. Author: 行者123. Updated: 2023-12-02 04:23:27

According to the TF documentation, the `sample_weight` argument can have shape `[batch_size]`. The relevant part of the documentation reads:

sample_weight: Optional Tensor whose rank is either 0, or the same rank as y_true, or is broadcastable to y_true. sample_weight acts as a coefficient for the loss. If a scalar is provided, then the loss is simply scaled by the given value. If sample_weight is a tensor of size [batch_size], then the total loss for each sample of the batch is rescaled by the corresponding element in the sample_weight vector. If the shape of sample_weight matches the shape of y_pred, then the loss of each measurable element of y_pred is scaled by the corresponding value of sample_weight.
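The first two documented cases (a scalar weight and a `[batch_size]` weight) can be sketched in NumPy. This is not TensorFlow code: `reduce_weighted` and the loss values below are hypothetical, and the sketch assumes that Keras's default reduction multiplies each per-sample loss by its weight, sums the result, and divides by the number of samples. The third case (weights shaped like `y_pred`) is not covered here.

```python
import numpy as np

def reduce_weighted(per_sample_loss, sample_weight):
    """Hypothetical sketch of a 'sum over batch size' reduction:
    weight each per-sample loss, sum, then divide by the number of samples."""
    per_sample_loss = np.asarray(per_sample_loss, dtype=float)
    # A scalar weight broadcasts to every sample; a [batch_size] weight is element-wise.
    weighted = per_sample_loss * np.asarray(sample_weight, dtype=float)
    return weighted.sum() / per_sample_loss.size

losses = np.array([0.5, 2.0, 1.0, 0.25])  # made-up per-sample losses, shape [4]

print(reduce_weighted(losses, 1.0))                   # 0.9375 (plain mean)
print(reduce_weighted(losses, 2.0))                   # 1.875  (scalar: loss doubled)
print(reduce_weighted(losses, [0.0, 1.0, 0.0, 0.0]))  # 0.5    (only sample 1 counts, divided by 4)
```

Note that under this assumption a zero weight removes a sample's contribution but does not shrink the denominator.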



However, I don't understand why the following code does not work:
import tensorflow as tf

gt = tf.convert_to_tensor([1, 1, 1, 1, 1])
pred = tf.convert_to_tensor([1., 0., 1., 1., 0.])
sample_weights = tf.convert_to_tensor([0, 1, 0, 0, 0])

loss = tf.keras.losses.BinaryCrossentropy()(gt, pred, sample_weight=sample_weights)
print(loss)

The code throws this error:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Can not squeeze dim[0], expected a dimension of 1, got 5 [Op:Squeeze]
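A simplified NumPy model of what appears to go wrong: the element-wise loss is first averaged over the last axis, so 1-D inputs of shape `[5]` produce a rank-0 loss; then `sample_weight` is brought to the loss's rank by squeezing its last axis, which only works when that axis has size 1. The helper `match_weight_rank` is hypothetical and only mirrors that rule; it is not TensorFlow's internal code.

```python
import numpy as np

def match_weight_rank(losses, sample_weight):
    """Hypothetical, simplified rank-matching rule: give sample_weight the same
    rank as the per-sample losses by squeezing or expanding its last axis."""
    losses = np.asarray(losses)
    sample_weight = np.asarray(sample_weight)
    if sample_weight.ndim - losses.ndim == 1:
        # Only valid when the last axis has size 1; NumPy raises ValueError otherwise.
        sample_weight = np.squeeze(sample_weight, axis=-1)
    elif losses.ndim - sample_weight.ndim == 1:
        sample_weight = np.expand_dims(sample_weight, axis=-1)
    return sample_weight

# np.zeros(5) stands in for five element-wise losses from inputs of shape (5,);
# averaging over the last axis leaves a rank-0 (scalar) loss.
loss = np.mean(np.zeros(5), axis=-1)

try:
    match_weight_rank(loss, [0, 1, 0, 0, 0])  # weight rank 1 vs. loss rank 0
except ValueError as err:
    print("ValueError:", err)
```

NumPy's `squeeze` fails here for the same reason TensorFlow reports `Can not squeeze dim[0]`: the axis being squeezed has size 5, not 1.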



If I expand the dimensions of `gt`, `pred`, and `sample_weights`, it works correctly and outputs the expected loss value of 3.0849898:
import tensorflow as tf

gt = tf.convert_to_tensor([1, 1, 1, 1, 1])
pred = tf.convert_to_tensor([1., 0., 1., 1., 0.])
sample_weights = tf.convert_to_tensor([0, 1, 0, 0, 0])

# expand dims
gt = tf.expand_dims(gt, 1)
pred = tf.expand_dims(pred, 1)
sample_weights = tf.expand_dims(sample_weights, 1)

loss = tf.keras.losses.BinaryCrossentropy()(gt, pred, sample_weight=sample_weights)
print(loss) # loss is 3.0849898
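To see where 3.0849898 comes from, here is a NumPy sketch of the expanded-dimension case. It assumes probabilities are clipped to `[eps, 1 - eps]` with `eps = 1e-7` and that the loss is computed as `-(y*log(p + eps) + (1 - y)*log(1 - p + eps))`, which is the form recent TF 2.x backends appear to use; `keras_like_bce` is a hypothetical name. Only sample 1 (prediction 0, weight 1) contributes, giving `-log(2e-7) ≈ 15.425`, and the default reduction divides by the batch size of 5 rather than by the sum of the weights.

```python
import numpy as np

EPS = 1e-7  # assumption: Keras's default fuzz factor

def keras_like_bce(y_true, y_pred, sample_weight):
    """Hypothetical sketch of BinaryCrossentropy with a 'sum over batch size' reduction."""
    p = np.clip(y_pred, EPS, 1.0 - EPS)
    bce = y_true * np.log(p + EPS) + (1.0 - y_true) * np.log(1.0 - p + EPS)
    per_sample = -np.mean(bce, axis=-1)                # average over labels: shape [batch_size]
    weights = np.reshape(sample_weight, per_sample.shape)  # assumes one weight per sample
    return np.sum(per_sample * weights) / per_sample.size

gt = np.array([1, 1, 1, 1, 1], dtype=float)[:, None]              # shape (5, 1)
pred = np.array([1., 0., 1., 1., 0.])[:, None]                    # shape (5, 1)
sample_weights = np.array([0, 1, 0, 0, 0], dtype=float)[:, None]  # shape (5, 1)

print(keras_like_bce(gt, pred, sample_weights))  # ≈ 3.08499
```

Under these assumptions the result agrees with the value TensorFlow prints above up to float32 rounding.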

Best answer

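The problem is not the shape of `sample_weight`; it is the shape of `pred` and `gt`. `BinaryCrossentropy` averages over the last axis, so a 1-D tensor of shape `[5]` is treated as one sample with five labels and its loss collapses to a scalar, which a `[5]`-shaped `sample_weight` cannot be matched against. `gt` and `pred` should have shape `[batch_size, n_labels]`: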

import tensorflow as tf

gt = tf.convert_to_tensor([1, 1, 1, 1, 1])
pred = tf.convert_to_tensor([1., 0., 1., 1., 0.])
sample_weights = tf.convert_to_tensor([0, 1, 0, 0, 0])  # left as shape (5,), i.e. [batch_size]

# expand dims
gt = tf.expand_dims(gt, 1)
pred = tf.expand_dims(pred, 1)
print(gt.shape, pred.shape) #(5, 1) (5, 1)

loss = tf.keras.losses.BinaryCrossentropy()(gt, pred, sample_weight=sample_weights)
print(loss) # loss is 3.0849898
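The answer's point about shapes can be checked in NumPy: averaging over the last axis treats that axis as the label axis, so only inputs shaped `[batch_size, n_labels]` leave one loss per sample for a `[batch_size]` weight to line up with. `per_sample_shape` is a hypothetical helper for illustration.

```python
import numpy as np

def per_sample_shape(y_shape):
    """Shape that remains after averaging an element-wise loss over the last axis."""
    return np.mean(np.zeros(y_shape), axis=-1).shape

sample_weights = np.array([0, 1, 0, 0, 0])  # shape (5,) = [batch_size]

print(per_sample_shape((5,)))     # ()   -> one sample with five labels
print(per_sample_shape((5, 1)))   # (5,) -> five samples with one label each
print(per_sample_shape((5, 1)) == sample_weights.shape)  # True
```

In TensorFlow, `tf.expand_dims(x, 1)` as used in the answer (or, equivalently, `tf.reshape(x, [-1, 1])`) turns a `[batch_size]` tensor into `[batch_size, 1]`.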

A similar question on this topic can be found on Stack Overflow: https://stackoverflow.com/questions/57541534/
