gpt4 book ai didi

python - 将类对象添加到 Pytorch Dataloader : batch must contain tensors

转载 作者:行者123 更新时间:2023-12-05 09:37:01 46 4
gpt4 key购买 nike

我有一个自定义 Pytorch 数据集,它返回一个包含类对象“查询”的字典。

class QueryDataset(torch.utils.data.Dataset):

def __init__(self, queries, values, targets):
super(QueryDataset).__init__()
self.queries = queries
self.values = values
self.targets = targets

def __len__(self):
return self.values.shape[0]

def __getitem__(self, idx):
sample = DeviceDict({'query': self.queries[idx],
"values": self.values[idx],
"targets": self.targets[idx]})
return sample

问题是,当我将查询放入数据加载器时,我得到了 default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'query.Query'> .有没有办法在我的数据加载器中有一个类对象?它在 next(iterator) 处爆炸在下面的代码中。

train_queries = QueryDataset(train_queries)
train_loader = torch.utils.data.DataLoader(train_queries,
batch_size=10],
shuffle=True,
drop_last=False)
for i in range(epochs):
iterator = iter(train_loader)
for i in range(len(train_loader)):
batch = next(iterator)
out = model(batch)
loss = criterion(out["pred"], batch["targets"])
self.optimizer.zero_grad()
loss.sum().backward()
self.optimizer.step()

最佳答案

你需要自己定义colate_fn为此。一个草率的方法只是为了向您展示这里的东西是如何工作的,就像这样:

import torch
class DeviceDict:
def __init__(self, data):
self.data = data

def print_data(self):
print(self.data)

class QueryDataset(torch.utils.data.Dataset):

def __init__(self, queries, values, targets):
super(QueryDataset).__init__()
self.queries = queries
self.values = values
self.targets = targets

def __len__(self):
return 5

def __getitem__(self, idx):
sample = {'query': self.queries[idx],
"values": self.values[idx],
"targets": self.targets[idx]}
return sample

def custom_collate(dict):
return DeviceDict(dict)

dt = QueryDataset("q","v","t")
dl = torch.utils.data.DataLoader(dtt,batch_size=1,collate_fn=custom_collate)
t = next(iter(dl))
t.print_data()

基本上 colate_fn 允许您实现自定义批处理或添加对自定义数据类型的支持,如我之前提供的链接中所述。
如您所见,它只是展示了概念,您需要根据自己的需要对其进行更改。

关于python - 将类对象添加到 Pytorch Dataloader : batch must contain tensors,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64586575/

46 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com