gpt4 book ai didi

python - Pytorch 数据生成器,用于从许多 3D 立方体中提取 2D 图像

转载 作者:行者123 更新时间:2023-12-04 12:55:43 26 4
gpt4 key购买 nike

我正在努力在 PyTorch 中创建数据生成器,以从 .dat 中保存的许多 3D 立方体中提取 2D 图像。格式
一共有200每个具有 128*128*128 的 3D 立方体形状。现在我想从所有这些立方体中沿长度和宽度提取 2D 图像。
例如,a是一个大小为 128*128*128 的立方体
所以我想沿长度提取所有 2D 图像,即 [:, i, :]这将使我沿长度获得 128 个 2D 图像,同样我想沿宽度提取,即 [:, :, i] ,这将给我 128 个沿宽度的 2D 图像。因此,我从 1 个 3D 立方体中总共得到 256 个 2D 图像,我想对所有 200 个立方体重复整个过程,给我 51200 个 2D 图像。
到目前为止,我已经尝试了一个非常基本的实现,它运行良好,但需要大约 10 分钟才能运行。我希望你们帮助我创建一个更优化的实现,同时考虑到时间和空间的复杂性。现在我目前的方法有 O(n2) 的时间复杂度,我们可以进一步降低时间复杂度吗?
我在当前的实现下面提供

from os.path import join as pjoin
import torch
import numpy as np
import os
from tqdm import tqdm
from torch.utils import data


class DataGenerator(data.Dataset):

def __init__(self, is_transform=True, augmentations=None):

self.is_transform = is_transform
self.augmentations = augmentations
self.dim = (128, 128, 128)

seismicSections = [] #Input
faultSections = [] #Ground Truth
for fileName in tqdm(os.listdir(pjoin('train', 'seis')), total = len(os.listdir(pjoin('train', 'seis')))):
unrolledVolSeismic = np.fromfile(pjoin('train', 'seis', fileName), dtype = np.single) #dat file contains unrolled cube, we need to reshape it
reshapedVolSeismic = np.transpose(unrolledVolSeismic.reshape(self.dim)) #need to transpose the axis to get height axis at axis = 0, while length (axis = 1), and width(axis = 2)

unrolledVolFault = np.fromfile(pjoin('train', 'fault', fileName),dtype=np.single)
reshapedVolFault = np.transpose(unrolledVolFault.reshape(self.dim))

for idx in range(reshapedVolSeismic.shape[2]):
seismicSections.append(reshapedVolSeismic[:, :, idx])
faultSections.append(reshapedVolFault[:, :, idx])

for idx in range(reshapedVolSeismic.shape[1]):
seismicSections.append(reshapedVolSeismic[:, idx, :])
faultSections.append(reshapedVolFault[:, idx, :])

self.seismicSections = seismicSections
self.faultSections = faultSections

def __len__(self):
return len(self.seismicSections)

def __getitem__(self, index):

X = self.seismicSections[index]
Y = self.faultSections[index]

return X, Y
请帮忙!!!

最佳答案

为什么不只将 3D 数据存储在 mem 中,而让 __getitem__方法即时“切片”它?

class CachedVolumeDataset(Dataset):
def __init__(self, ...):
super(...)
self._volumes_x = # a list of 200 128x128x128 volumes
self._volumes_y = # a list of 200 128x128x128 volumes

def __len__(self):
return len(self._volumes_x) * (128 + 128)

def __getitem__(self, index):
# extract volume index from general index:
vidx = index // (128 + 128)
# extract slice index
sidx = index % (128 + 128)
if sidx < 128:
# first dim
x = self._volumes_x[vidx][:, :, sidx]
y = self._volumes_y[vidx][:, :, sidx]
else:
sidx -= 128
# second dim
x = self._volumes_x[vidx][:, sidx, :]
y = self._volumes_y[vidx][:, sidx, :]
return torch.squeeze(x), torch.squeeze(y)

关于python - Pytorch 数据生成器,用于从许多 3D 立方体中提取 2D 图像,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67956318/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com