作者热门文章
- html - 出于某种原因,IE8 对我的 Sass 文件中继承的 html5 CSS 不友好?
- JMeter 在响应断言中使用 span 标签的问题
- html - 在 :hover and :active? 上具有不同效果的 CSS 动画
- html - 相对于居中的 html 内容固定的 CSS 重复背景?
我正在使用TensorDataset
从 numpy 数组创建数据集。
# convert numpy arrays to pytorch tensors
X_train = torch.stack([torch.from_numpy(np.array(i)) for i in X_train])
y_train = torch.stack([torch.from_numpy(np.array(i)) for i in y_train])
# reshape into [C, H, W]
X_train = X_train.reshape((-1, 1, 28, 28)).float()
# create dataset and dataloaders
train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64)
如何将数据增强 ( transforms ) 应用于 TensorDataset
?
例如,使用 ImageFolder
,我可以将转换指定为其参数之一 torchvision.datasets.ImageFolder(root, transform=...)
。
根据this reply由 PyTorch 团队成员之一提出,默认情况下不支持。有没有其他方法可以做到这一点?
欢迎询问是否需要更多代码来解释问题。
最佳答案
默认情况下,TensorDataset
不支持转换。但我们可以创建自定义类来添加该选项。但是,正如我已经提到的,大多数转换都是为 PIL.Image
开发的。但无论如何,这是一个非常简单的 MNIST 示例,具有非常虚拟的变换。带有 MNIST 的 csv 文件 here .
代码:
import numpy as np
import torch
from torch.utils.data import Dataset, TensorDataset
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# Import mnist dataset from cvs file and convert it to torch tensor
with open('mnist_train.csv', 'r') as f:
mnist_train = f.readlines()
# Images
X_train = np.array([[float(j) for j in i.strip().split(',')][1:] for i in mnist_train])
X_train = X_train.reshape((-1, 1, 28, 28))
X_train = torch.tensor(X_train)
# Labels
y_train = np.array([int(i[0]) for i in mnist_train])
y_train = y_train.reshape(y_train.shape[0], 1)
y_train = torch.tensor(y_train)
del mnist_train
class CustomTensorDataset(Dataset):
"""TensorDataset with support of transforms.
"""
def __init__(self, tensors, transform=None):
assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
self.tensors = tensors
self.transform = transform
def __getitem__(self, index):
x = self.tensors[0][index]
if self.transform:
x = self.transform(x)
y = self.tensors[1][index]
return x, y
def __len__(self):
return self.tensors[0].size(0)
def imshow(img, title=''):
"""Plot the image batch.
"""
plt.figure(figsize=(10, 10))
plt.title(title)
plt.imshow(np.transpose( img.numpy(), (1, 2, 0)), cmap='gray')
plt.show()
# Dataset w/o any tranformations
train_dataset_normal = CustomTensorDataset(tensors=(X_train, y_train), transform=None)
train_loader = torch.utils.data.DataLoader(train_dataset_normal, batch_size=16)
# iterate
for i, data in enumerate(train_loader):
x, y = data
imshow(torchvision.utils.make_grid(x, 4), title='Normal')
break # we need just one batch
# Let's add some transforms
# Dataset with flipping tranformations
def vflip(tensor):
"""Flips tensor vertically.
"""
tensor = tensor.flip(1)
return tensor
def hflip(tensor):
"""Flips tensor horizontally.
"""
tensor = tensor.flip(2)
return tensor
train_dataset_vf = CustomTensorDataset(tensors=(X_train, y_train), transform=vflip)
train_loader = torch.utils.data.DataLoader(train_dataset_vf, batch_size=16)
result = []
for i, data in enumerate(train_loader):
x, y = data
imshow(torchvision.utils.make_grid(x, 4), title='Vertical flip')
break
train_dataset_hf = CustomTensorDataset(tensors=(X_train, y_train), transform=hflip)
train_loader = torch.utils.data.DataLoader(train_dataset_hf, batch_size=16)
result = []
for i, data in enumerate(train_loader):
x, y = data
imshow(torchvision.utils.make_grid(x, 4), title='Horizontal flip')
break
输出:
关于python - PyTorch 在 TensorDataset 上进行转换,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55588201/
我正在使用TensorDataset从 numpy 数组创建数据集。 # convert numpy arrays to pytorch tensors X_train = torch.stack([
我是一名优秀的程序员,十分优秀!