gpt4 book ai didi

python - Pytorch-类型错误 : 'torch.Size' object cannot be interpreted as an integer

转载 作者:行者123 更新时间:2023-11-30 22:03:20 25 4
gpt4 key购买 nike

您好,我正在训练 PyTorch 模型并发生此错误:

----> 5 代表 i,枚举中的数据(trainloader, 0):

类型错误:“torch.Size”对象无法解释为整数

不确定这个错误意味着什么。

您可以在这里找到我的代码:

model.train()
for epoch in range(10):
running_loss = 0

for i, data in enumerate(trainloader, 0):

inputs, labels = data

optimizer.zero_grad()

outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()

if i % 2000 == 0:
print (loss.item())
running_loss += loss.item()
if i % 1000 == 0:
print ('[%d, %5d] loss: %.3f' % (epoch, i, running_loss/ 1000))
running_loss = 0

torch.save(model, 'FeatureNet.pkl')
<小时/>

更新

这是 DataLoader 的代码块。我正在使用自定义的数据加载器和数据集,其中 x 是大小为 (1025, 16) 的图片,y 是用于分类的 one-hot 编码向量。

x_train.shape = (1100, 1025, 16)

y_train.shape = (1100, 10)

clean_dir = '/home/tk/Documents/clean/' 
mix_dir = '/home/tk/Documents/mix/'
clean_label_dir = '/home/tk/Documents/clean_labels/'
mix_label_dir = '/home/tk/Documents/mix_labels/'

class MSourceDataSet(Dataset):

def __init__(self, clean_dir, mix_dir, clean_label_dir, mix_label_dir):

with open(clean_dir + 'clean0.json') as f:
clean0 = torch.Tensor(json.load(f))

with open(mix_dir + 'mix0.json') as f:
mix0 = torch.Tensor(json.load(f))

with open(clean_label_dir + 'clean_label0.json') as f:
clean_label0 = torch.Tensor(json.load(f))


with open(mix_label_dir + 'mix_label0.json') as f:
mix_label0 = torch.Tensor(json.load(f))


self.spec = torch.cat([clean0, mix0], 0)
self.label = torch.cat([clean_label0, mix_label0], 0)

def __len__(self):
return self.spec.shape


def __getitem__(self, index):

spec = self.spec[index]
label = self.label[index]
return spec, label

获取项目

a, b = trainset.__getitem__(1000)
print (a.shape)
print (b.shape)

a.shape = torch.Size([1025, 16]);b.shape = torch.Size([10])

错误消息

---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-9-3bd71e5c00e1> in <module>()
3 running_loss = 0
4
----> 5 for i, data in enumerate(trainloader, 0):
6
7 inputs, labels = data

~/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
311 def __next__(self):
312 if self.num_workers == 0: # same-process loading
--> 313 indices = next(self.sample_iter) # may raise StopIteration
314 batch = self.collate_fn([self.dataset[i] for i in indices])
315 if self.pin_memory:

~/anaconda3/lib/python3.7/site-packages/torch/utils/data/sampler.py in __iter__(self)
136 def __iter__(self):
137 batch = []
--> 138 for idx in self.sampler:
139 batch.append(idx)
140 if len(batch) == self.batch_size:

~/anaconda3/lib/python3.7/site-packages/torch/utils/data/sampler.py in __iter__(self)
32
33 def __iter__(self):
---> 34 return iter(range(len(self.data_source)))
35
36 def __len__(self):

TypeError: 'torch.Size' object cannot be interpreted as an integer

最佳答案

您的问题是 __len__ 函数。您不能使用shape作为返回值。

下面是一个示例:

import torch
class Foo:
def __init__(self, data):
self.data = data
def __len__(self):
return self.data.shape

myFoo = Foo(data=torch.rand(10, 20))
print(len(myFoo))

会引发完全相同的错误:

---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-285-e97aace2f622> in <module>
7
8 myFoo = Foo(data=torch.rand(10, 20))
----> 9 print(len(myFoo))

TypeError: 'torch.Size' object cannot be interpreted as an integer

由于 shape 代表一个 torch.Size 元组:

print(myFoo.data.shape)

输出:

torch.Size([10, 20])

所以你必须决定要将哪个维度交给__len__,例如第一个维度:

import torch
class Foo:
def __init__(self, data):
self.data = data
def __len__(self):
return self.data.shape[0] # choosing first dimension for len

myFoo = Foo(data=torch.rand(10, 20))
print(len(myFoo))
# prints 10

工作正常并返回10。当然,您也可以选择输入的任何其他维度,但您必须选择一个。

因此,在您的 MSourceDataSet 代码中,您必须将 __len__ 函数更改为:

def __len__(self):
return self.spec.shape[0] # as said of course you can also choose other dimensions

这应该可以解决您的问题。

关于python - Pytorch-类型错误 : 'torch.Size' object cannot be interpreted as an integer,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53588623/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com