gpt4 book ai didi

python - 如何在pytorch中使用反向传播和自定义损失?

转载 作者:行者123 更新时间:2023-11-30 09:46:23 28 4
gpt4 key购买 nike

我正在尝试实现一个暹罗网络,该网络在两个图像之间存在排名损失。如果我定义自己的损失,我是否能够执行如下反向传播步骤?当我运行它时,有时在我看来它给出的结果与单个网络给出的结果相同。

with torch.set_grad_enabled(phase == 'train'):
outputs1 = model(inputs1)
outputs2 = model(inputs2)
preds1 = outputs1;
preds2 = outputs2;

alpha = 0.02;
w_r = torch.tensor(1).cuda(async=True);

y_i, y_j, predy_i, predy_j = labels1,labels2,outputs1,outputs2;
batchRankLoss = torch.tensor([max(0,alpha - delta(y_i[i], y_j[i])*predy_i[i] - predy_j[i])) for i in range(batchSize)],dtype = torch.float)
rankLossPrev = torch.mean(batchRankLoss)
rankLoss = Variable(rankLossPrev,requires_grad=True)

loss1 = criterion(outputs1, labels1)
loss2 = criterion(outputs2, labels2)


#total loss = loss1 + loss2 + w_r*rankLoss
totalLoss = torch.add(loss1,loss2)
w_r = w_r.type(torch.LongTensor)
rankLossPrev = rankLossPrev.type(torch.LongTensor)
mult = torch.mul(w_r.type(torch.LongTensor),rankLossPrev).type(torch.FloatTensor)
totalLoss = torch.add(totalLoss,mult.cuda(async = True));

# backward + optimize only if in training phase
if phase == 'train':
totalLoss.backward()
optimizer.step()

running_loss += totalLoss.item() * inputs1.size(0)

最佳答案

您有几行可以从构造函数或转换为另一种数据类型生成新的张量。执行此操作时,您将断开希望 backwards() 命令通过其进行区分的操作链。

此转换会断开图表,因为转换是不可微分的:

w_r = w_r.type(torch.LongTensor)

从构造函数构建张量将断开图的连接:

batchRankLoss = torch.tensor([max(0,alpha - delta(y_i[i], y_j[i])*predy_i[i] - predy_j[i])) for i in range(batchSize)],dtype = torch.float)

根据文档,将张量包装在变量中会将 grad_fn 设置为 None (也会断开图表的连接):

rankLoss = Variable(rankLossPrev,requires_grad=True)

假设您的 critereon 函数是可微分的,则梯度当前仅通过 loss1loss2 向后流动。您的其他渐变只会流动到 mult,然后才会通过调用 type() 停止。这与您的观察结果一致,即您的自定义损失不会改变神经网络的输出。

要允许梯度通过自定义损失向后流动,您必须编写相同的逻辑,同时避免 type() 转换并计算 rankLoss 而不使用列表理解力。

关于python - 如何在pytorch中使用反向传播和自定义损失?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52059399/

28 4 0
文章推荐: python - 如何让 CatBoost get_object_importance 与 AUC 配合使用?
文章推荐: java - Facelets:只需要一个页面的 标签
文章推荐: javascript - jQuery:带有关闭按钮的 fadeOut
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com