gpt4 book ai didi

pytorch - 如何计算pytorch中BCEWithLogitsLoss的不平衡权重

转载 作者:行者123 更新时间:2023-12-03 17:08:28 24 4
gpt4 key购买 nike

我正在尝试使用 270 解决一个多标签问题标签,我已将目标标签转换为一种热编码形式。我正在使用 BCEWithLogitsLoss() .由于训练数据不平衡,我使用 pos_weight争论,但我有点困惑。
pos_weight (张量,可选)——正例的权重。必须是长度等于类数的向量。

我是否需要将每个标签的正值的总数作为张量给出,或者它们的权重意味着其他东西?

最佳答案

PyTorch documentation for BCEWithLogitsLoss建议 pos_weight 是每个类的负计数和正计数之间的比率。
所以,如果 len(dataset)是 1000,多热编码的元素 0 有 100 个正计数,然后是 pos_weights_vector 的元素 0应该是 900/100 = 9 .这意味着二元交叉损失将表现为数据集包含 900 个正样本而不是 100 个。
这是我的实现:

  def calculate_pos_weights(class_counts):
pos_weights = np.ones_like(class_counts)
neg_counts = [len(data)-pos_count for pos_count in class_counts]
for cdx, pos_count, neg_count in enumerate(zip(class_counts, neg_counts)):
pos_weights[cdx] = neg_count / (pos_count + 1e-5)

return torch.as_tensor(pos_weights, dtype=torch.float)
哪里 class_counts只是正样本的列式总和。我 posted it在 PyTorch 论坛上,一位 PyTorch 开发人员给了它祝福。

关于pytorch - 如何计算pytorch中BCEWithLogitsLoss的不平衡权重,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57021620/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com