gpt4 book ai didi

PyTorch实现重写/改写Dataset并载入Dataloader

转载 作者:qq735679552 更新时间:2022-09-29 22:32:09 30 4
gpt4 key购买 nike

CFSDN坚持开源创造价值,我们致力于搭建一个资源共享平台,让每一个IT人在这里找到属于你的精彩世界.

这篇CFSDN的博客文章PyTorch实现重写/改写Dataset并载入Dataloader由作者收集整理,如果你对这篇文章有兴趣,记得点赞哟.

前言 。

众所周知,Dataset和Dataloder是pytorch中进行数据载入的部件。必须将数据载入后,再进行深度学习模型的训练。在pytorch的一些案例教学中,常使用torchvision.datasets自带的MNIST、CIFAR-10数据集,一般流程为:

?
1
2
3
4
# 下载并存放数据集
train_dataset = torchvision.datasets.CIFAR10(root = "数据集存放位置" ,download = True )
# load数据
train_loader = torch.utils.data.DataLoader(dataset = train_dataset)

但是,在我们自己的模型训练中,需要使用非官方自制的数据集。这时应该怎么办呢?

我们可以通过改写torch.utils.data.Dataset中的__getitem__和__len__来载入我们自己的数据集。 __getitem__获取数据集中的数据,__len__获取整个数据集的长度(即个数).

改写 。

采用pytorch官网案例中提供的一个脸部landmark数据集。数据集中含有存放landmark的csv文件,但是我们在这篇文章中不使用(其实也可以随便下载一些图片作数据集来实验).

?
1
2
3
4
5
6
7
8
9
import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
 
plt.ion()  # interactive mode

torch.utils.data.Dataset是一个抽象类,我们自己的数据集需要继承Dataset,然后改写上述两个函数:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class ImageLoader(Dataset):
   def __init__( self , file_path, transform = None ):
     super (ImageLoader, self ).__init__()
     self .file_path = file_path
     self .transform = transform # 对输入图像进行预处理,这里并没有做,预设为None
     self .image_names = os.listdir( self .file_path) # 文件名的列表
    
   def __getitem__( self ,idx):
     image = self .image_names[idx]
     image = io.imread(os.path.join( self .file_path,image))
#    if self.transform:
#       image= self.transform(image)
     return image
         
   def __len__( self ):
     return len ( self .image_names)
 
# 设置自己存放的数据集位置,并plot展示   
imageloader = ImageLoader(file_path = "D:\\Projects\\datasets\\faces\\" )
# imageloader.__len__()       # 输出数据集长度(个数),应为71
# print(imageloader.__getitem__(0)) # 以数据形式展示
plt.imshow(imageloader.__getitem__( 0 )) # 以图像形式展示
plt.show()

得到的图片输出:

PyTorch实现重写/改写Dataset并载入Dataloader

得到的数据输出,:

?
1
2
3
4
5
6
7
8
9
10
11
array([[[ 66 , 59 , 53 ],
     [ 66 , 59 , 53 ],
     [ 66 , 59 , 53 ],
     ...,
     [ 59 , 54 , 48 ],
     [ 59 , 54 , 48 ],
     [ 59 , 54 , 48 ]],
     ...,
     [ 153 , 141 , 129 ],
     [ 158 , 146 , 134 ],
     [ 158 , 146 , 134 ]]], dtype = uint8)

上面看到dytpe=uint8,实际进行训练的时候,常常需要更改成float的数据类型。可以使用:

?
1
2
3
# 直接改成pytorch中的tensor下的float格式
# 也可以用numpy的改成普通的float格式
to_float = torch.from_numpy(imageloader.__getitem__( 0 )). float ()

改写完成后,直接使用train_loader =torch.utils.data.DataLoader(dataset=imageloader)载入到Dataloader中,就可以使用了。 下面的代码可以试着运行一下,产生的是一模一样的图片结果.

?
1
2
3
4
train_loader = torch.utils.data.DataLoader(dataset = imageloader)
train_loader.dataset[ 0 ]
plt.imshow(train_loader.dataset[ 0 ])
plt.show()

到此这篇关于PyTorch实现重写/改写Dataset并载入Dataloader的文章就介绍到这了,更多相关PyTorch重写/改写Dataset 内容请搜索我以前的文章或继续浏览下面的相关文章希望大家以后多多支持我! 。

原文链接:https://blog.csdn.net/qq_38372240/article/details/107322677 。

最后此篇关于PyTorch实现重写/改写Dataset并载入Dataloader的文章就讲到这里了,如果你想了解更多关于PyTorch实现重写/改写Dataset并载入Dataloader的内容请搜索CFSDN的文章或继续浏览相关文章,希望大家以后支持我的博客! 。

30 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com