
PyTorch DataLoader: verifying shuffle behavior


shuffle = False: the data is read in its original order, without shuffling.

shuffle = True: the sample order is randomly shuffled.

import numpy as np
import h5py
import torch
from torch.utils.data import DataLoader, Dataset

# Write a small HDF5 file with 4 samples and their labels.
h5f = h5py.File('train.h5', 'w')
data1 = np.array([[1, 2, 3],
                  [2, 5, 6],
                  [3, 5, 6],
                  [4, 5, 6]])
data2 = np.array([[1, 1, 1],
                  [1, 2, 6],
                  [1, 3, 6],
                  [1, 4, 6]])
h5f.create_dataset('data', data=data1)
h5f.create_dataset('label', data=data2)
h5f.close()  # close the write handle before reopening the file for reading


class H5Dataset(Dataset):
    def __init__(self):
        h5f = h5py.File('train.h5', 'r')
        self.data = h5f['data']
        self.label = h5f['label']

    def __getitem__(self, index):
        data = torch.from_numpy(self.data[index])
        label = torch.from_numpy(self.label[index])
        return data, label

    def __len__(self):
        assert self.data.shape[0] == self.label.shape[0], "wrong data length"
        return self.data.shape[0]


dataset_train = H5Dataset()
loader_train = DataLoader(dataset=dataset_train,
                          batch_size=2,
                          shuffle=True)

# Print each batch; with shuffle=True the sample order changes between epochs/runs.
for i, data in enumerate(loader_train):
    train_data, label = data
    print(train_data)
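For comparison, here is a minimal sketch (using `dataset_train` as defined above; the seed value is only a placeholder) that prints the batches for both settings and shows how to make the shuffled order reproducible by passing a seeded generator:

# Sketch, assuming dataset_train from the code above.
# shuffle=False always yields storage order; shuffle=True yields a different
# order each epoch unless a seeded torch.Generator is supplied.
g = torch.Generator()
g.manual_seed(0)  # hypothetical seed, only to make the shuffled order repeatable

loader_ordered = DataLoader(dataset_train, batch_size=2, shuffle=False)
loader_shuffled = DataLoader(dataset_train, batch_size=2, shuffle=True, generator=g)

print("shuffle=False:")
for data, label in loader_ordered:
    print(data.tolist())

print("shuffle=True (seeded):")
for data, label in loader_shuffled:
    print(data.tolist())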

PyTorch DataLoader usage details

Background:

At first I was confused about data augmentation: I could only find the data transforms (torchvision.transforms) but nothing explicitly called augmentation. It later became clear that in PyTorch, data augmentation is the combined effect of torchvision.transforms, torch.utils.data.DataLoader, and training over multiple epochs.

The data transforms used here are the following.

from torchvision import transforms

composed = transforms.Compose([transforms.Resize((448, 448)),              # resize
                               transforms.RandomCrop(300),                 # random crop
                               transforms.ToTensor(),
                               transforms.Normalize(mean=[0.5, 0.5, 0.5],  # normalize
                                                    std=[0.5, 0.5, 0.5])])
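As a quick check, applying `composed` to a PIL image yields a normalized tensor of shape [3, 300, 300]. A sketch, assuming an arbitrary RGB image; the file name "example.jpg" is only a placeholder:

from PIL import Image

im = Image.open("example.jpg").convert("RGB")  # placeholder file name
out = composed(im)           # resize -> random crop -> to tensor -> normalize
print(out.shape)             # torch.Size([3, 300, 300])
print(out.min(), out.max())  # roughly within [-1, 1] after Normalize(0.5, 0.5)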

A simple data-loading class that only returns images in PIL format:

import csv
import os

from PIL import Image
from torch.utils import data


class MyDataset(data.Dataset):
    def __init__(self, labels_file, root_dir, transform=None):
        with open(labels_file) as csvfile:
            self.labels_file = list(csv.reader(csvfile))
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.labels_file)

    def __getitem__(self, idx):
        # The first CSV column holds the image file name relative to root_dir.
        im_name = os.path.join(self.root_dir, self.labels_file[idx][0])
        im = Image.open(im_name)

        if self.transform:
            im = self.transform(im)

        return im

The main program is below.

import matplotlib.pyplot as plt

labels_file = "F:/test_temp/labels.csv"
root_dir = "F:/test_temp"
dataset_transform = MyDataset(labels_file, root_dir, transform=composed)
dataloader = data.DataLoader(dataset_transform, batch_size=1, shuffle=False)

"""The original dataset has 3 images; with batch_size=1 and 2 epochs, all 6 images are displayed."""
for epoch in range(2):
    plt.figure(figsize=(6, 6))
    for ind, i in enumerate(dataloader):
        # i has shape [1, 3, 300, 300]; drop the batch dim and move channels last for imshow.
        a = i[0, :, :, :].numpy().transpose((1, 2, 0))
        plt.subplot(1, 3, ind + 1)
        plt.imshow(a)
plt.show()

[Figure: the transformed images plotted for each of the two epochs]

As the figures above show, the transforms are re-applied to the original images in every epoch, and this is what produces the data augmentation.
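This can also be confirmed without looking at plots by fetching the same sample twice: with a random transform such as RandomCrop, the two results generally differ. A sketch, assuming the `dataset_transform` object defined above:

import torch

# Fetch the same sample twice; RandomCrop makes the crops (and therefore the
# tensors) differ with high probability.
first_pass = dataset_transform[0]
second_pass = dataset_transform[0]
print(torch.equal(first_pass, second_pass))  # usually False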

The above is my personal experience; I hope it can serve as a useful reference.

Original article: https://blog.csdn.net/qq_35752161/article/details/110875040
