
python - torch TypeError: '<' not supported between instances of 'Example' and 'Example' when referring to iterator

Reposted | Author: 行者123 | Updated: 2023-12-04 15:41:45

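I am trying to classify text with my own dataset, following https://github.com/bentrevett/pytorch-sentiment-analysis/blob/master/5%20-%20Multi-class%20Sentiment%20Analysis.ipynb. My dataset is a CSV of sentences and the class associated with each one. There are 6 different classes: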

sent                        class
'the fox is brown'          animal
'the house is big'          object
'one water is drinkable'    water
...

When I run:

N_EPOCHS = 5

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()
    print(start_time)
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)
    #print(train_loss.type())
    #print(train_acc.type())
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)

    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut5-model.pt')

    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')

I get the following error:

1566460706.977012
epoch_loss
<torchtext.data.iterator.BucketIterator object at 0x000001FABE907E80>
TypeError: '<' not supported between instances of 'Example' and 'Example'

pointing to:

TypeError                                 Traceback (most recent call last)
<ipython-input-22-19e8a7eb204e> in <module>()
     10     #print(train_loss.type())
     11     #print(train_acc.type())
---> 12     valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)
     13
     14     end_time = time.time()

<ipython-input-21-83b02f99bca7> in evaluate(model, iterator, criterion)
      9         print('epoch_loss')
     10         print(iterator)
---> 11         for batch in iterator:
     12             print('batch')
     13             predictions = model(batch.text)
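For context, this is the message Python 3 gives when `sorted()` (or `list.sort()`) has to compare two objects whose class defines no ordering. The sketch below uses a hypothetical stand-in `Example` class (a plain attribute container, not torchtext's own class) and a hypothetical `sort_examples` helper to reproduce the same message in plain Python:

```python
class Example:
    """Hypothetical stand-in for torchtext's Example: a plain attribute container."""

    def __init__(self, sent, label):
        self.sent = sent
        self.label = label


def sort_examples(examples, key=None):
    # With key=None, sorted() compares the Example objects themselves using '<',
    # and Example defines no __lt__, so Python raises TypeError.
    return sorted(examples, key=key)


examples = [Example("the house is big", "object"), Example("the fox is brown", "animal")]

try:
    sort_examples(examples)
except TypeError as err:
    print(err)
```

Nothing about the objects is wrong as such; the comparison simply has nothing to compare unless a key is supplied.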

I'm new to PyTorch, so I just added a line to print the iterator and identify its type, and got:

<torchtext.data.iterator.BucketIterator object at 0x000001FABE907E80>

I tried to identify the relevant attributes by going through https://github.com/pytorch/text/blob/master/torchtext/data/iterator.py, to no avail.

Any suggestions would be appreciated.

The code for the evaluate method (where the iterator is used) is:

def evaluate(model, iterator, criterion):

    epoch_loss = 0
    epoch_acc = 0

    model.eval()

    with torch.no_grad():
        print('epoch_loss')
        print(iterator)
        for batch in iterator:
            print('batch')
            predictions = model(batch.text)

            loss = criterion(predictions, batch.label)

            acc = categorical_accuracy(predictions, batch.label)

            epoch_loss += loss.item()
            epoch_acc += acc.item()

    return epoch_loss / len(iterator), epoch_acc / len(iterator)
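Why the failure would show up in evaluate rather than in train: as far as I can tell from the legacy torchtext iterator.py linked above, an Iterator created with sort=None sorts its data only when train=False; BucketIterator.splits passes train=True only for the first dataset; sort_within_batch follows sort unless set; and a sort_key left as None falls back to the dataset's sort_key, which a dataset loaded from a CSV may not define. The function below is a hypothetical pure-Python model of that decision (the name `resolve_sort_settings` is made up), not torchtext code:

```python
def resolve_sort_settings(train, sort=None, sort_within_batch=None,
                          sort_key=None, dataset_sort_key=None):
    """Hypothetical model of how a legacy torchtext Iterator picks its sort settings."""
    # Sorting is on by default only for non-training iterators.
    effective_sort = (not train) if sort is None else sort
    # sort_within_batch follows sort unless it is set explicitly.
    effective_within = effective_sort if sort_within_batch is None else sort_within_batch
    # An explicit sort_key wins; otherwise fall back to the dataset's (possibly None).
    effective_key = dataset_sort_key if sort_key is None else sort_key
    return effective_sort, effective_within, effective_key


# splits() is assumed to mark only the first dataset as the training split.
train_settings = resolve_sort_settings(train=True)
valid_settings = resolve_sort_settings(train=False)
print(train_settings)  # no sorting at all, so Examples are never compared
print(valid_settings)  # sorting with key None, so the Examples themselves are compared
```

Under these assumptions the training iterator never compares Examples, while the validation iterator does, which matches the traceback pointing at `evaluate`.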

Best answer

I ran into a similar problem and solved it by passing sort_key and sort_within_batch when creating the iterators.

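from torchtext.data import BucketIterator  # legacy torchtext; versions 0.9 to 0.11 moved it to torchtext.legacy.data

train_iterator, valid_iterator = BucketIterator.splits(
    (train_data, valid_data),  # the datasets; not named `train` so they don't shadow the train() function
    batch_size = BATCH_SIZE,
    sort_key = lambda x: len(x.sent),  # 'sent' must be the name of your text field (the question's model reads batch.text)
    sort_within_batch = True,
    device = device)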
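To see why passing sort_key helps, here is a plain-Python sketch, again with a hypothetical stand-in `Example` class rather than torchtext itself: with a key, `sorted()` compares sentence lengths instead of the objects, and sorting each batch by that key in descending order is roughly what sort_within_batch=True does per batch. The `make_batches` helper is illustrative only; it batches after one full sort and ignores BucketIterator's shuffling and pooling.

```python
class Example:
    """Hypothetical stand-in for torchtext's Example."""

    def __init__(self, sent, label):
        self.sent = sent.split()  # tokenized sentence, as a text field would store it
        self.label = label


def sort_key(example):
    # Same idea as the answer's lambda: order Examples by sentence length.
    return len(example.sent)


def make_batches(examples, batch_size, key):
    """Illustrative only: sort all examples by key, chunk them, then sort each
    chunk longest-first, mirroring what sort_within_batch=True does per batch."""
    ordered = sorted(examples, key=key)
    batches = [ordered[i:i + batch_size] for i in range(0, len(ordered), batch_size)]
    return [sorted(b, key=key, reverse=True) for b in batches]


data = [
    Example("the fox is brown", "animal"),
    Example("the house is big", "object"),
    Example("one water is drinkable", "water"),
    Example("a cat", "animal"),
]

for batch in make_batches(data, batch_size=2, key=sort_key):
    print([" ".join(ex.sent) for ex in batch])
```

Longest-first order within a batch is the order `torch.nn.utils.rnn.pack_padded_sequence` expects by default, which is why the tutorial's packed-sequence models pair a length-based sort_key with sort_within_batch=True.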

Regarding python - torch TypeError: '<' not supported between instances of 'Example' and 'Example' when referring to iterator, a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/57605217/
