- ubuntu12.04环境下使用kvm ioctl接口实现最简单的虚拟机
- Ubuntu 通过无线网络安装Ubuntu Server启动系统后连接无线网络的方法
- 在Ubuntu上搭建网桥的方法
- ubuntu 虚拟机上网方式及相关配置详解
CFSDN坚持开源创造价值,我们致力于搭建一个资源共享平台,让每一个IT人在这里找到属于你的精彩世界.
这篇CFSDN的博客文章PyTorch实现重写/改写Dataset并载入Dataloader由作者收集整理,如果你对这篇文章有兴趣,记得点赞哟.
前言 。
众所周知,Dataset和Dataloder是pytorch中进行数据载入的部件。必须将数据载入后,再进行深度学习模型的训练。在pytorch的一些案例教学中,常使用torchvision.datasets自带的MNIST、CIFAR-10数据集,一般流程为:
1
2
3
4
|
# 下载并存放数据集
train_dataset
=
torchvision.datasets.CIFAR10(root
=
"数据集存放位置"
,download
=
True
)
# load数据
train_loader
=
torch.utils.data.DataLoader(dataset
=
train_dataset)
|
但是,在我们自己的模型训练中,需要使用非官方自制的数据集。这时应该怎么办呢?
我们可以通过改写torch.utils.data.Dataset中的__getitem__和__len__来载入我们自己的数据集。 __getitem__获取数据集中的数据,__len__获取整个数据集的长度(即个数).
改写 。
采用pytorch官网案例中提供的一个脸部landmark数据集。数据集中含有存放landmark的csv文件,但是我们在这篇文章中不使用(其实也可以随便下载一些图片作数据集来实验).
1
2
3
4
5
6
7
8
9
|
import
os
import
torch
from
skimage
import
io, transform
import
numpy as np
import
matplotlib.pyplot as plt
from
torch.utils.data
import
Dataset, DataLoader
from
torchvision
import
transforms, utils
plt.ion()
# interactive mode
|
torch.utils.data.Dataset是一个抽象类,我们自己的数据集需要继承Dataset,然后改写上述两个函数:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
|
class
ImageLoader(Dataset):
def
__init__(
self
, file_path, transform
=
None
):
super
(ImageLoader,
self
).__init__()
self
.file_path
=
file_path
self
.transform
=
transform
# 对输入图像进行预处理,这里并没有做,预设为None
self
.image_names
=
os.listdir(
self
.file_path)
# 文件名的列表
def
__getitem__(
self
,idx):
image
=
self
.image_names[idx]
image
=
io.imread(os.path.join(
self
.file_path,image))
# if self.transform:
# image= self.transform(image)
return
image
def
__len__(
self
):
return
len
(
self
.image_names)
# 设置自己存放的数据集位置,并plot展示
imageloader
=
ImageLoader(file_path
=
"D:\\Projects\\datasets\\faces\\"
)
# imageloader.__len__() # 输出数据集长度(个数),应为71
# print(imageloader.__getitem__(0)) # 以数据形式展示
plt.imshow(imageloader.__getitem__(
0
))
# 以图像形式展示
plt.show()
|
得到的图片输出:
得到的数据输出,:
1
2
3
4
5
6
7
8
9
10
11
|
array([[[
66
,
59
,
53
],
[
66
,
59
,
53
],
[
66
,
59
,
53
],
...,
[
59
,
54
,
48
],
[
59
,
54
,
48
],
[
59
,
54
,
48
]],
...,
[
153
,
141
,
129
],
[
158
,
146
,
134
],
[
158
,
146
,
134
]]], dtype
=
uint8)
|
上面看到dytpe=uint8,实际进行训练的时候,常常需要更改成float的数据类型。可以使用:
1
2
3
|
# 直接改成pytorch中的tensor下的float格式
# 也可以用numpy的改成普通的float格式
to_float
=
torch.from_numpy(imageloader.__getitem__(
0
)).
float
()
|
改写完成后,直接使用train_loader =torch.utils.data.DataLoader(dataset=imageloader)载入到Dataloader中,就可以使用了。 下面的代码可以试着运行一下,产生的是一模一样的图片结果.
1
2
3
4
|
train_loader
=
torch.utils.data.DataLoader(dataset
=
imageloader)
train_loader.dataset[
0
]
plt.imshow(train_loader.dataset[
0
])
plt.show()
|
到此这篇关于PyTorch实现重写/改写Dataset并载入Dataloader的文章就介绍到这了,更多相关PyTorch重写/改写Dataset 内容请搜索我以前的文章或继续浏览下面的相关文章希望大家以后多多支持我! 。
原文链接:https://blog.csdn.net/qq_38372240/article/details/107322677 。
最后此篇关于PyTorch实现重写/改写Dataset并载入Dataloader的文章就讲到这里了,如果你想了解更多关于PyTorch实现重写/改写Dataset并载入Dataloader的内容请搜索CFSDN的文章或继续浏览相关文章,希望大家以后支持我的博客! 。
我一直在尝试通过一些我为自己开发的练习题重新认识 SQL,但正在努力寻找更好的方法来解决以下问题: 播放列表 id title 1 Title1 2 Title2 播放列表剪辑 id pl
我有一个很大的制表符分隔文件,如下所示: chr1 9507728 9517729 0 chr1 9507728 9517729 5S_rRNA chr1 9537731 954
我经常想编辑提交消息,而不必从上次提交中重新选择文件集。 git commit file1.c file2.c 提交消息中的意外拼写错误。 git commit file1.c file2.c --a
如何编辑或改写 merge 提交的消息? git commit --amend 如果它是最后一次提交 (HEAD) 则可以工作,但是如果它在 HEAD 之前呢? git rebase -i HEAD~
我在 flutter 中有以下声明。 weight是来自 _weightController 的文本,即 _weightController.text int.parse(weight).toStri
我是一名优秀的程序员,十分优秀!