
python - How do I convert RGB images to grayscale in a PyTorch data loader?

Reposted · Author: 太空宇宙 · Updated: 2023-11-03 15:41:10

I downloaded some sample images from the MNIST dataset in .jpg format, and I am now loading them to test my pretrained model.

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# transforms to apply to the data
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# MNIST dataset
test_dataset = datasets.ImageFolder(root=DATA_PATH, transform=trans)

# Data loader
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

Here, DATA_PATH contains a single subfolder holding the sample images.

Here is my network definition:

import torch.nn as nn

# Convolutional neural network (two convolutional layers)
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.network2D = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),  # expects 1 input channel
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.network1D = nn.Sequential(
            nn.Dropout(),
            nn.Linear(7 * 7 * 64, 1000),
            nn.Linear(1000, 10))

    def forward(self, x):
        out = self.network2D(x)
        out = out.reshape(out.size(0), -1)  # flatten to [N, 7*7*64]
        out = self.network1D(out)
        return out
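A quick sanity check of what this network accepts, sketched under the assumption that the ConvNet above is used unchanged (it is re-declared so the snippet runs on its own). The first Conv2d has a weight of shape [32, 1, 5, 5], so it only accepts 1 input channel, and the two MaxPool2d layers halve 28 to 14 and then to 7, which is where 7 * 7 * 64 comes from. Feeding dummy tensors of shape [N, 1, 28, 28] and [N, 3, 28, 28] shows the difference.

```python
import torch
import torch.nn as nn

# Same architecture as the ConvNet above, repeated so this snippet is self-contained
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.network2D = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),  # 1 input channel only
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),                 # 28x28 -> 14x14
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))                 # 14x14 -> 7x7
        self.network1D = nn.Sequential(
            nn.Dropout(),
            nn.Linear(7 * 7 * 64, 1000),
            nn.Linear(1000, 10))

    def forward(self, x):
        out = self.network2D(x)
        out = out.reshape(out.size(0), -1)
        return self.network1D(out)

model = ConvNet().eval()

# A single-channel batch goes through and yields 10 logits per image
with torch.no_grad():
    gray_out = model(torch.zeros(2, 1, 28, 28))
print(gray_out.shape)  # torch.Size([2, 10])

# A 3-channel batch fails in the first Conv2d, reproducing the error in the question
try:
    with torch.no_grad():
        model(torch.zeros(1, 3, 28, 28))
    rgb_failed = False
except RuntimeError:
    rgb_failed = True
print(rgb_failed)  # True
```

So the fix has to happen on the data side: every batch reaching the model must have shape [N, 1, 28, 28].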

Here is my inference code:

# Test the model
model = torch.load("mnist_weights_5.pth.tar")
model.eval()

for images, labels in test_loader:
    outputs = model(images.cuda())

When I run this code, I get the following error:

RuntimeError: Given groups=1, weight of size [32, 1, 5, 5], expected input[1, 3, 28, 28] to have 1 channels, but got 3 channels instead

I understand that the images are being loaded as 3-channel (RGB) images. How can I convert them to a single channel in the data loader?

Update: I changed the transforms to include the Grayscale option:

trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)), transforms.Grayscale(num_output_channels=1)])

But now I get this error:

TypeError: img should be PIL Image. Got <class 'torch.Tensor'>

Best answer

When you use the ImageFolder class without a custom loader, PyTorch loads each image with PIL and converts it to RGB. If the torchvision image backend is PIL, the default loader is:

def pil_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

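You can use torchvision's Grayscale transform in your transform pipeline. It converts a 3-channel RGB image into a 1-channel grayscale image. It must come before ToTensor(): in the torchvision version used here, Grayscale operates on PIL images, which is why placing it after ToTensor() (as in the update above) raises "img should be PIL Image. Got <class 'torch.Tensor'>". See the torchvision transforms documentation for more information.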

Here is some example code (it uses CIFAR-10, whose images are RGB):

import torchvision as tv
import torch.utils.data as data

dataDir = 'D:\\general\\ML_DL\\datasets\\CIFAR'
# Grayscale runs first, while each sample is still a PIL image
trainTransform = tv.transforms.Compose([tv.transforms.Grayscale(num_output_channels=1),
                                        tv.transforms.ToTensor(),
                                        # one channel after Grayscale, so one mean/std value
                                        tv.transforms.Normalize((0.5,), (0.5,))])
trainSet = tv.datasets.CIFAR10(dataDir, train=True, download=False, transform=trainTransform)
dataloader = data.DataLoader(trainSet, batch_size=1, shuffle=False, num_workers=0)
images, labels = next(iter(dataloader))
print(images.size())  # torch.Size([1, 1, 32, 32])
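If you would rather keep the RGB data loader and collapse the channels at inference time, a weighted sum over the channel dimension also works. The sketch below is an alternative to the Grayscale transform, not what the answer uses; `rgb_batch_to_gray` is a hypothetical helper, and it assumes the ITU-R 601-2 luma weights (0.299, 0.587, 0.114) that PIL uses for `convert('L')`. Because those weights sum to 1 and the question's Normalize applies the same mean/std to every channel, collapsing after normalization gives the same result as normalizing a grayscale image, up to floating-point rounding.

```python
import torch

def rgb_batch_to_gray(images: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: collapse an [N, 3, H, W] batch to [N, 1, H, W]
    using ITU-R 601-2 luma weights (as in PIL's convert('L'))."""
    weights = torch.tensor([0.299, 0.587, 0.114], dtype=images.dtype, device=images.device)
    # Weighted sum over the channel dimension; keepdim=True keeps a channel axis of size 1
    return (images * weights.view(1, 3, 1, 1)).sum(dim=1, keepdim=True)

batch = torch.rand(4, 3, 28, 28)
gray = rgb_batch_to_gray(batch)
print(gray.shape)  # torch.Size([4, 1, 28, 28])

# MNIST digits saved as .jpg and loaded as RGB typically have three identical channels,
# in which case the result is simply that shared channel.
same = torch.rand(2, 1, 5, 5).repeat(1, 3, 1, 1)
same_gray = rgb_batch_to_gray(same)
```

In the question's loop this would become `outputs = model(rgb_batch_to_gray(images).cuda())`. The Grayscale transform is still the cleaner fix, because it keeps the model's input format decided in one place.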

Regarding "python - How do I convert RGB images to grayscale in a PyTorch data loader?", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/52439364/
