gpt4 book ai didi

python - 是否建议使用相同的 torch Dataset 类进行训练和预测?

转载 作者:行者123 更新时间:2023-12-03 18:47:51 24 4
gpt4 key购买 nike

我最近开始使用 PyTorch,我喜欢它的面向对象风格。但是,我想知道在预测模型时什么是最好的和建议的工作流程。我想使用我编写的自定义 Dataset 类,并将其用于训练和验证我的模型。这个类是一个 map 风格的数据集,因此我实现了__getitem__返回图像和目标的方法:

class CustomDataset:

def __init__(self, ...):
...

def __getitem__(self, image_id):
....
return (
torch.tensor(image, dtype=torch.float),
torch.tensor(target, dtype=torch.long),
)
然而,当我使用这个类进行预测时,我没有任何返回的目标。我目前的解决方法是这样的
def __getitem__(self, image_id):
....
if predict:
return (
torch.tensor(image, dtype=torch.float),
np.nan,
)
else:
return (
torch.tensor(image, dtype=torch.float),
torch.tensor(target, dtype=torch.long),
)
但是,我想知道是否有更好的方法来做到这一点。同时,由于感觉有点不自然,我开始怀疑是否建议使用同一个类进行训练和预测(应该是,但我的解决方案的笨拙让我怀疑)。当然,我根本无法返回元组,而只能返回第一个元素,但这仍然需要 if-else。

最佳答案

PyTorch 的 DataSet类真的很简单。所以,不要想太多。它只不过是访问数据的包装器。
您不必返回元组,甚至不需要返回张量。你可以返回任何你想要的数据。通常,它将采用以下样式之一:

  • 对于无监督数据:Sample(Sample, None)
  • 对于监督数据:(Sample, Label)
  • 对于具有多个目标的监督数据,例如物体检测:(Sample, [Label1, Label2, ...])(Sample, Label1, Label2, ...)

  • 也是 common使用相同的 DataSet 类进行训练/测试。
    因此,在您的情况下,只需返回样本或元组 (sample, None) as done in torchvision并相应地调整您的管道。我不建议使用 np.nan因为它会失败一个简单的 None 检查( np.nan == None )。另外,我鼓励你继承 torch.data.Dataset .
    但是,如果您的管道强制您使用元组或有其他限制,我建议您重新表述您的问题。

    关于python - 是否建议使用相同的 torch Dataset 类进行训练和预测?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67445508/

    24 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com