gpt4 book ai didi

python - 带有数据生成器的 keras VAE

转载 作者:行者123 更新时间:2023-12-05 04:54:21 25 4
gpt4 key购买 nike

我正在构建一个加载巨大数据集的 VAE。输入数据是维度为 (batch_size, 48, 48, 48) 的 3D 二进制体素数据。为了在训练中一个一个地加载数据,我构建了一个生成器,如下所示

class DataGenerator(keras.utils.Sequence):

def __init__(self, x_set, y_set, batch_size):
self.x = x_set # path for each dataset
self.y = y_set
self.batch_size = batch_size

def __len__(self):
return len(self.x) #batch size = data[0]

def __getitem__(self, idx):
batch_x = self.x[idx]
return np.load(batch_x).astype("float32"), None

在我尝试训练模型后,我收到如下错误消息:

NotImplementedError: When subclassing the `Model` class, you should implement a `call` method.

但再次尝试后,不知何故模型正在运行。谁能帮我解决这个问题?

这里还有另一个问题,因为这不是分类问题而且没有标签,我只需要放置一些没有 y 值的 x_test 数据集进行验证,但我收到如下错误:


~\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
857
858 # Run validation.
--> 859 if validation_data and self._should_eval(epoch, validation_freq):
860 val_x, val_y, val_sample_weight = (
861 data_adapter.unpack_x_y_sample_weight(validation_data))

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

如果有人有经验对这种情况进行验证,请发表评论。

谢谢!!

最佳答案

调用方法错误

这个错误:

NotImplementedError: When subclassing the `Model` class, you should implement a `call` method.

表示您的模型子类未实现 call 方法,您需要它才能使用 keras 高阶方法(例如 fit)。当您以功能方式 model_instance(data) 威胁您的模型时,此方法只是一个包装器。在训练期间(使用 fit 方法时)它会在 epoch 结束时调用以计算准确性/验证/测试。

你可以通过将它定义到你的模型类中来修复:

class MyModel(keras.model.Model):
...

def call(self, data, training=False):
# your custom code when you call the model
# or just pass, you don't need this method
# for training
pass

def train_step(self, data):
# here is where you customazie your model training
# using backpropagation and other custom steps.
# Implementing this method is optional
...

无论如何,您的 DataGenerator 代码格式不适合 python(可能在复制粘贴时缺少制表符或空格)。还要确保缩进正确。

值错误

第二个错误与您可能将无效参数传递给 fit y 这一事实有关。您可以只传递 None 或忽略参数,但不要传递列表,否则会出错:

# DO
model.fit(x_dataset, batch_size=my_batch_size, others_parameter=...)
# or
model.fit(x_dataset, None, my_batch_size, other_parameters...)

# DONT
model.fit(x_dataset, y=[], batch_size=my_batch_size, other_parameters=...)
# or
model.fit(x_dataset, [], batch_size, other_parameters...)

验证的需要源于防止模型过度拟合和调整其超参数。

另请参阅 Keras documentation关于 VAE。

关于python - 带有数据生成器的 keras VAE,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65817032/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com