gpt4 book ai didi

Pytorch DataLoader 变长数据处理方式

转载 作者:qq735679552 更新时间:2022-09-29 22:32:09 30 4
gpt4 key购买 nike

CFSDN坚持开源创造价值,我们致力于搭建一个资源共享平台,让每一个IT人在这里找到属于你的精彩世界.

这篇CFSDN的博客文章Pytorch DataLoader 变长数据处理方式由作者收集整理,如果你对这篇文章有兴趣,记得点赞哟.

关于Pytorch中怎么自定义Dataset数据集类、怎样使用DataLoader迭代加载数据,这篇官方文档已经说得很清楚了,这里就不在赘述.

现在的问题:有的时候,特别对于NLP任务来说,输入的数据可能不是定长的,比如多个句子的长度一般不会一致,这时候使用DataLoader加载数据时,不定长的句子会被胡乱切分,这肯定是不行的.

解决方法是重写DataLoader的collate_fn,具体方法如下:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# 假如每一个样本为:
sample = {
     # 一个句子中各个词的id
     'token_list' : [ 5 , 2 , 4 , 1 , 9 , 8 ],
     # 结果y
     'label' : 5 ,
}
 
 
# 重写collate_fn函数,其输入为一个batch的sample数据
def collate_fn(batch):
     # 因为token_list是一个变长的数据,所以需要用一个list来装这个batch的token_list
   token_lists = [item[ 'token_list' ] for item in batch]
  
   # 每个label是一个int,我们把这个batch中的label也全取出来,重新组装
   labels = [item[ 'label' ] for item in batch]
   # 把labels转换成Tensor
   labels = torch.Tensor(labels)
   return {
     'token_list' : token_lists,
     'label' : labels,
   }
 
 
# 在使用DataLoader加载数据时,注意collate_fn参数传入的是重写的函数
DataLoader(trainset, batch_size = 4 , shuffle = True , num_workers = 4 , collate_fn = collate_fn)

使用以上方法,可以保证DataLoader能Load出一个batch的数据,load出来的东西就是重写的collate_fn函数最后return出来的字典.

以上这篇Pytorch DataLoader 变长数据处理方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我.

原文链接:https://blog.csdn.net/HappyCtest/article/details/88872651 。

最后此篇关于Pytorch DataLoader 变长数据处理方式的文章就讲到这里了,如果你想了解更多关于Pytorch DataLoader 变长数据处理方式的内容请搜索CFSDN的文章或继续浏览相关文章,希望大家以后支持我的博客! 。

30 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com