gpt4 book ai didi

python - 对于多输入模型,tf.data.Dataset.from_tensor_slices 的替代方法是什么?

转载 作者:行者123 更新时间:2023-12-04 07:20:25 26 4
gpt4 key购买 nike

我正在尝试制作一个需要两个输入的多输入 Keras 模型。一个输入是图像,第二个输入是一维文本。我将图像的路径存储在数据框中,然后将图像附加到这样的列表中:

from tqdm import tqdm

train_images = []
for image_path in tqdm(train_df['paths']):
byte_file = tf.io.read_file(image_path)
img = tf.image.decode_png(byte_file)
train_images.append(img)
一维文本输入存储在列表中。对验证集和测试集重复此过程。然后我制作一个数据集,如下所示:
train_protein = tf.expand_dims(padded_train_protein_encode,axis=2)
training_dataset = tf.data.Dataset.from_tensor_slices(((train_protein, train_images), train_Labels))

training_dataset = training_dataset.batch(20)

val_protein = tf.expand_dims(padded_val_protein_encode, axis=2)
validation_dataset = tf.data.Dataset.from_tensor_slices(((val_protein, val_images), validation_Labels))
validation_dataset = validation_dataset.batch(20)

test_protein = tf.expand_dims(padded_test_protein_encode, axis=2)
test_dataset = tf.data.Dataset.from_tensor_slices(((test_protein, test_images), test_Labels))
test_dataset = test_dataset.batch(20)

我在 Google Colab 中运行它,即使使用高内存选项,程序也会由于内存不足而崩溃。解决这个问题的最佳方法是什么?
我已经研究了 tf.data.Dataset.from_generator 作为一个选项,但是当有两个输入时我无法弄清楚如何使它工作。任何人都可以帮忙吗?

最佳答案

这是一种相当常见的疼痛。如果您的数据集太大而无法加载到内存中,那么没有比数据生成器更好的方法了。来自 PyTorch,有 pythonic 类可以做到这一点,而不必使用 tf.data.Dataset.from_generator .子类化 tf.keras.utils.Sequence可能是一个优雅的选择。无法访问您的数据集,我无法验证,但这样的事情应该可行。__getitem__被称为每批。

class TfDataGenerator(tf.keras.utils.Sequence):
def __init__(self, filepaths, proteins, labels):
self.filepaths = np.array(filepaths)
self.proteins = np.array(proteins)
self.labels = labels

def __len__(self):
return len(self.filenames) // self.batch_size

def __getitem__(self, index):
indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
return __generate_x(indexes), labels[indexes]

def __generate_x(self, indexes):
x_1 = np.empty((self.batch_size, *self.dim, self.n_channels))
x_2 = np.empty((self.batch_size, len(self.meta_features)))

for index in enumerate(indexes):
image = cv2.imread(self.filepaths[index])
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
x_1[num] = image.astype(np.float32)/255.
x_2[num] = self.proteins[index]

return [x_1, x_2]

def on_epoch_end(self):
self.indexes = np.arange(len(self.filenames))
if self.shuffle:
np.random.shuffle(self.indexes)
再次,一个非常粗略的例子,但希望它表明可以做什么。 Tensorflow 文档 here
过去这让我很头疼,所以希望这个答案有所帮助。

关于python - 对于多输入模型,tf.data.Dataset.from_tensor_slices 的替代方法是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/68547097/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com