
Python: generating unique batches from a given dataset

Reposted · Author: 行者123 · Updated: 2023-12-05 05:52:34

I am applying a CNN to classify a given dataset.

My function:

import cv2

def batch_generator(dataset, input_shape=(256, 256), batch_size=32):
    dataset_images = []
    dataset_labels = []
    for i in range(0, len(dataset)):
        # Load each image in color and resize it to input_shape
        dataset_images.append(cv2.resize(cv2.imread(dataset[i], cv2.IMREAD_COLOR),
                                         input_shape, interpolation=cv2.INTER_AREA))
        # The label comes from the parent directory name via the `labels` dict
        dataset_labels.append(labels[dataset[i].split('/')[-2]])
    return dataset_images, dataset_labels

This function should be called once per epoch, and it should return a unique batch of size batch_size, containing dataset_images (each image 256x256) and the corresponding dataset_labels from the labels dictionary.

The input "dataset" contains the paths to all images, so I open them and resize them to 256x256. Can someone help me extend this code so that it returns the desired batches?
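For reference, the "unique batch per call" behavior the question asks for is usually done with a generator that shuffles the index order once per epoch and then yields consecutive slices. The sketch below (an assumption, not the asker's code) yields batches of paths only, so it stays runnable without cv2; the image loading and `labels` lookup from the question would go inside the loop.

```python
import random

def batch_generator(paths, batch_size=32, shuffle=True, seed=None):
    """Yield unique batches of paths, covering the dataset once per epoch.

    Loading/resizing each image with cv2 and looking up its label (as in
    the question's code) would happen where the batch is assembled below.
    """
    indices = list(range(len(paths)))
    if shuffle:
        random.Random(seed).shuffle(indices)  # new order each epoch
    for start in range(0, len(indices), batch_size):
        batch_idx = indices[start:start + batch_size]
        # Each index appears in exactly one batch, so every batch is unique
        yield [paths[i] for i in batch_idx]
```

Used as `for batch in batch_generator(dataset, batch_size=32): ...`, one full pass over the loop is one epoch, and the last batch may be smaller than batch_size.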

Accepted answer

As @jodag suggested, using DataLoaders is a good idea.

Here is a snippet I use for some CNNs in PyTorch:

from torch.utils.data import Dataset, DataLoader
import torch

class Data(Dataset):
    """
    Constructs a Dataset to be parsed into a DataLoader
    """
    def __init__(self, X, y):
        X = torch.from_numpy(X).float()

        # Transpose to fit the dimensions of my network
        X = torch.transpose(X, 1, 2)

        y = torch.from_numpy(y).float()
        self.X, self.y = X, y

    def __getitem__(self, i):
        return self.X[i], self.y[i]

    def __len__(self):
        return self.X.shape[0]

def create_data_loader(X, y, batch_size, **kwargs):
    """
    Creates a data loader for the data X and y

    params:
    -------

    X: np.array
        - numpy array of size "n" x "k", where "n" is the number of samples
          and "k" is the number of features

    y: np.array
        - numpy array of size "n"

    batch_size: int
        - number of samples per batch

    kwargs:
        - additional keyword arguments for "DataLoader"

    returns:
    -------

    dl: torch.utils.data.DataLoader object
    """
    data = Data(X, y)
    dl = DataLoader(data, batch_size=batch_size, num_workers=0, **kwargs)
    return dl
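For intuition, the batching that DataLoader performs over Data(X, y) is roughly equivalent to slicing a (possibly shuffled) index array, as in this NumPy-only sketch (an illustration under the assumption that X and y are NumPy arrays, not part of the answer's code):

```python
import numpy as np

def numpy_batches(X, y, batch_size, shuffle=True, seed=None):
    """Rough NumPy equivalent of iterating a DataLoader over Data(X, y):
    shuffle the sample indices once per epoch, then hand out consecutive
    slices of size batch_size (the last batch may be smaller)."""
    idx = np.arange(len(X))
    if shuffle:
        np.random.default_rng(seed).shuffle(idx)
    for start in range(0, len(idx), batch_size):
        sel = idx[start:start + batch_size]
        yield X[sel], y[sel]
```

DataLoader adds the conversion to tensors, optional multi-process loading (num_workers), and collation on top of this basic index-slicing scheme.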

Use it like this:

from create_data_loader import create_data_loader

# Note: DataLoader defaults to shuffle=False, so pass shuffle=True for training
train_data_loader = create_data_loader(X_train, y_train, batch_size=32, shuffle=True)
val_data_loader = create_data_loader(X_val, y_val, batch_size=32, shuffle=False)  # keep indices in order, e.g. for cross-validation

for x_train, y_train in train_data_loader:
    logit = net(x_train)
    # ... loss, backward pass, optimizer step ...

net.eval()
for x_val, y_val in val_data_loader:
    logit = net(x_val)
    classes_pred = logit.argmax(axis=1)
    print(f"Val accuracy: {(y_val == classes_pred).float().mean()}")

Regarding "Python: generating unique batches from a given dataset", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/70114892/
