gpt4 book ai didi

python - 如何提取两个张量之间不等价条目的索引?

转载 作者:行者123 更新时间:2023-12-01 00:24:10 24 4
gpt4 key购买 nike

我有一个张量,其中包含 N 个对象类别的 N 个预测,并且我有另一个张量,其中包含真实的 N 个目标对象类别。我想提取我的分类器预测错误的张量索引。

考虑以下两个张量定义为:

import torch
predictions = torch.tensor([ [0], [1], [1], [0], [0], [1] ])
target = torch.tensor([ [0], [0], [1], [1], [0], [1] ])

我想找到一些函数,我可以在其中传递这两个向量并返回一个类似 index_diff = [1, 3] 的列表。有这个功能吗?我当前的想法是将这两个向量转换为 numpy 数组,然后循环 N 次并比较每个索引处的每个条目,但这对我来说似乎有点迂回。有替代方案吗?

最佳答案

类似的东西

index_diff = (predictions.flatten() != target.flatten()).nonzero().flatten()

应该可以。

关于python - 如何提取两个张量之间不等价条目的索引?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58732795/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com