gpt4 book ai didi

python - 用于读取大型 parquet/csv 文件的 Pytorch Dataloader

转载 作者:行者123 更新时间:2023-12-05 04:48:13 44 4
gpt4 key购买 nike

我试图让 Pytorch 训练单个 Parquet 文件的记录,而不必一次读取内存中的整个文件,因为它不适合内存。由于文件是远程存储的,我宁愿将它作为一个文件保存,因为对许多文件使用 IO 进行训练非常昂贵。当我想在 DataLoader 中指定批处理数时,如何使用 Pytorch 的 IterableDatasetDataset 在训练期间读取较小的文件 block ?我知道 map 样式的 Dataset 在这种情况下不起作用,因为我需要一个文件中的所有内容,而不是读取每个文件的索引。

我设法使用 tfio.IODatasettf.data.Dataset 在 Tensorflow 中实现了这一点,但我找不到在 Pytorch 中实现它的等效方法。

最佳答案

我找到了使用 torch.utils.data.Dataset 的解决方法,但数据必须事先使用 dask 进行操作,这样每个分区都是一个用户,存储为自己的 Parquet 文件,但可以以后只读一次。在下面的代码中,标签和数据分别存储用于多元时间序列分类问题(但也可以很容易地适应其他任务):

import dask.dataframe as dd
import pandas as pd
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader, IterableDataset, Dataset

# Breakdown file
raw_ddf = dd.read_parquet("data.parquet") # Read huge file using dask
raw_ddf = raw_ddf.set_index("userid") # set the userid as index
userids = raw_ddf.index.unique().compute().values.tolist() # get a list of indices
new_ddf = raw_ddf.repartition(divisions = userids) # repartition by userids
new_ddf.to_parquet("my_folder") # this will save each user as its own parquet file within "my_folder"

# Dask to read the partitions
train_ddf = dd.read_parquet("my_folder/*.parquet") # read all files

# Read labels file
labels_df = pd.read("label.csv")
y_labels = np.array(labels_df["class"])

# Define the Dataset class
class UsersDataset(Dataset):
def __init__(self, dask_df, labels):
self.dask_df = dask_df
self.labels = labels

def __len__(self):
return len(self.labels)

def __getitem__(self, idx):
X_df = self.dask_df.get_partition(idx).compute()
X = np.row_stack([X_df])
X_tensor = torch.tensor(X, dtype=torch.float32)
y = self.labels[idx]
y_tensor = torch.tensor(y, dtype=torch.long)
sample = (X_tensor, y_tensor)
return sample

# Create a Dataset object
user_dataset = UsersDataset(dask_df=ddf_train, labels = y_train)

# Create a DataLoader object
dataloader = DataLoader(user_dataset, batch_size=4, shuffle=True, num_workers=0)

# Print output of the first batch to ensure it works
for i_batch, sample_batched in enumerate(dataloader):
print("Batch number ", i_batch)
print(sample_batched[0]) # print X
print(sample_batched[1]) # print y

# stop after first batch.
if i_batch == 0:
break

我想知道在使用 >= 2 个工作人员读取数据时如何调整我的方法,而没有重复条目。非常感谢任何对此的见解。

关于python - 用于读取大型 parquet/csv 文件的 Pytorch Dataloader,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/68199072/

44 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com