gpt4 book ai didi

Pytorch技巧:DataLoader的collate_fn参数使用详解

转载 作者:qq735679552 更新时间:2022-09-29 22:32:09 27 4
gpt4 key购买 nike

CFSDN坚持开源创造价值,我们致力于搭建一个资源共享平台,让每一个IT人在这里找到属于你的精彩世界.

这篇CFSDN的博客文章Pytorch技巧:DataLoader的collate_fn参数使用详解由作者收集整理,如果你对这篇文章有兴趣,记得点赞哟.

DataLoader完整的参数表如下:

?
1
2
3
4
5
6
7
8
9
10
11
12
class torch.utils.data.DataLoader(
  dataset,
  batch_size = 1 ,
  shuffle = False ,
  sampler = None ,
  batch_sampler = None ,
  num_workers = 0 ,
  collate_fn = <function default_collate>,
  pin_memory = False ,
  drop_last = False ,
  timeout = 0 ,
  worker_init_fn = None )

DataLoader在数据集上提供单进程或多进程的迭代器 。

几个关键的参数意思:

- shuffle:设置为True的时候,每个世代都会打乱数据集 。

- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能 。

- drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留 。

一个测试的例子 。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
import torch.utils.data as Data
import numpy as np
 
test = np.array([ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 ])
 
inputing = torch.tensor(np.array([test[i:i + 3 ] for i in range ( 10 )]))
target = torch.tensor(np.array([test[i:i + 1 ] for i in range ( 10 )]))
 
torch_dataset = Data.TensorDataset(inputing,target)
batch = 3
 
loader = Data.DataLoader(
  dataset = torch_dataset,
  batch_size = batch, # 批大小
  # 若dataset中的样本数不能被batch_size整除的话,最后剩余多少就使用多少
  collate_fn = lambda x:(
   torch.cat(
    [x[i][j].unsqueeze( 0 ) for i in range ( len (x))], 0
    ).unsqueeze( 0 ) for j in range ( len (x[ 0 ]))
   )
  )
 
for (i,j) in loader:
  print (i)
  print (j)

输出结果:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
tensor([[[ 0 , 1 , 2 ],
    [ 1 , 2 , 3 ],
    [ 2 , 3 , 4 ]]], dtype = torch.int32)
tensor([[[ 0 ],
    [ 1 ],
    [ 2 ]]], dtype = torch.int32)
tensor([[[ 3 , 4 , 5 ],
    [ 4 , 5 , 6 ],
    [ 5 , 6 , 7 ]]], dtype = torch.int32)
tensor([[[ 3 ],
    [ 4 ],
    [ 5 ]]], dtype = torch.int32)
tensor([[[ 6 , 7 , 8 ],
    [ 7 , 8 , 9 ],
    [ 8 , 9 , 10 ]]], dtype = torch.int32)
tensor([[[ 6 ],
    [ 7 ],
    [ 8 ]]], dtype = torch.int32)
tensor([[[ 9 , 10 , 11 ]]], dtype = torch.int32)
tensor([[[ 9 ]]], dtype = torch.int32)

如果不要collate_fn的值,输出变成 。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
tensor([[ 0 , 1 , 2 ],
   [ 1 , 2 , 3 ],
   [ 2 , 3 , 4 ]], dtype = torch.int32)
tensor([[ 0 ],
   [ 1 ],
   [ 2 ]], dtype = torch.int32)
tensor([[ 3 , 4 , 5 ],
   [ 4 , 5 , 6 ],
   [ 5 , 6 , 7 ]], dtype = torch.int32)
tensor([[ 3 ],
   [ 4 ],
   [ 5 ]], dtype = torch.int32)
tensor([[ 6 , 7 , 8 ],
   [ 7 , 8 , 9 ],
   [ 8 , 9 , 10 ]], dtype = torch.int32)
tensor([[ 6 ],
   [ 7 ],
   [ 8 ]], dtype = torch.int32)
tensor([[ 9 , 10 , 11 ]], dtype = torch.int32)
tensor([[ 9 ]], dtype = torch.int32)

所以collate_fn就是使结果多一维.

看看collate_fn的值是什么意思。我们把它改为如下 。

?
1
collate_fn = lambda x:x

并输出 。

?
1
2
for i in loader:
  print (i)

得到结果 。

?
1
2
3
4
[(tensor([ 0 , 1 , 2 ], dtype = torch.int32), tensor([ 0 ], dtype = torch.int32)), (tensor([ 1 , 2 , 3 ], dtype = torch.int32), tensor([ 1 ], dtype = torch.int32)), (tensor([ 2 , 3 , 4 ], dtype = torch.int32), tensor([ 2 ], dtype = torch.int32))]
[(tensor([ 3 , 4 , 5 ], dtype = torch.int32), tensor([ 3 ], dtype = torch.int32)), (tensor([ 4 , 5 , 6 ], dtype = torch.int32), tensor([ 4 ], dtype = torch.int32)), (tensor([ 5 , 6 , 7 ], dtype = torch.int32), tensor([ 5 ], dtype = torch.int32))]
[(tensor([ 6 , 7 , 8 ], dtype = torch.int32), tensor([ 6 ], dtype = torch.int32)), (tensor([ 7 , 8 , 9 ], dtype = torch.int32), tensor([ 7 ], dtype = torch.int32)), (tensor([ 8 , 9 , 10 ], dtype = torch.int32), tensor([ 8 ], dtype = torch.int32))]
[(tensor([ 9 , 10 , 11 ], dtype = torch.int32), tensor([ 9 ], dtype = torch.int32))]

每个i都是一个列表,每个列表包含batch_size个元组,每个元组包含TensorDataset的单独数据。所以要将重新组合成每个batch包含1*3*3的input和1*3*1的target,就要重新解包并打包。 看看我们的collate_fn:

?
1
2
3
4
5
collate_fn = lambda x:(
  torch.cat(
   [x[i][j].unsqueeze( 0 ) for i in range ( len (x))], 0
   ).unsqueeze( 0 ) for j in range ( len (x[ 0 ]))
  )

j取的是两个变量:input和target。i取的是batch_size。然后通过unsqueeze(0)方法在前面加一维。torch.cat(,0)将其打包起来。然后再通过unsqueeze(0)方法在前面加一维。 完成.

以上这篇Pytorch技巧:DataLoader的collate_fn参数使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我.

原文链接:https://blog.csdn.net/weixin_42028364/article/details/81675021 。

最后此篇关于Pytorch技巧:DataLoader的collate_fn参数使用详解的文章就讲到这里了,如果你想了解更多关于Pytorch技巧:DataLoader的collate_fn参数使用详解的内容请搜索CFSDN的文章或继续浏览相关文章,希望大家以后支持我的博客! 。

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com