gpt4 book ai didi

python - 在训练期间将非张量参数传递给 Keras 模型/使用张量进行索引

转载 作者:行者123 更新时间:2023-12-04 08:13:30 26 4
gpt4 key购买 nike

我正在尝试训练一个在模型本身中包含数据增强的 Keras 模型。模型的输入是不同类的图像,模型应该为每个类生成一个增强模型,用于增强过程。我的代码大致如下:

from keras.models import Model
from keras.layers import Input
...further imports...

def get_main_model(input_shape, n_classes):
encoder_model = get_encoder_model()
input = Input(input_shape, name="input")
label_input = Input((1,), name="label_input")
aug_models = [get_augmentation_model() for i in range(n_classes)]
augmentation = aug_models[label_input](input)
x = encoder_model(input)
y = encoder_model(augmentation)
model = Model(inputs=[input, label_input], outputs=[x, y])
model.add_loss(custom_loss_function(x, y))
return model
然后,我想通过由图像数组(传递给 input )和相应的标签数组(传递给 label_input )组成的模型传递批量数据。但是,这不起作用,因为输入到 label_input 的任何内容都被 Tensorflow 转换为张量,并且不能用于以下索引。我试过的是以下内容:
  • augmentation = aug_models[int(label_input)](input) --> 不起作用
    因为 label_input is a tensor
  • augmentation = aug_models[tf.make_ndarray(label_input)](input) --> 转换不起作用(我猜是因为 label_input 是一个符号张量)
  • tf.gather(aug_models, label_input) --> 不起作用,因为操作的结果是一个 Keras 模型实例,Tensorflow 试图将其转换为张量(显然失败)

  • Tensorflow 中是否有任何技巧可以让我在训练期间将参数传递给模型,该参数不会转换为张量,或者我可以告诉模型选择哪个增强模型的不同方式?提前致谢!

    最佳答案

    input 的每个元素应用不同的增强张量(例如以 label_input 为条件),您需要:

  • 首先,为批次的每个元素计算每个可能的增强。
  • 其次,根据标签选择所需的增强。

  • 不幸的是,索引是不可能的,因为 inputlabel_input张量是多维的(例如,如果您要对批次的每个元素应用相同的增强,则可以使用任何条件 tensorflow 语句,例如 tf.case )。

    这是一个最小的工作示例,展示了如何实现这一目标:
    input = tf.ones((3, 1))  # Shape=(bs, 1)
    label_input = tf.constant([3, 2, 1]) # Shape=(bs, 1)
    aug_models = [lambda x: x, lambda x: x * 2, lambda x: x * 3, lambda x: x * 4]
    nb_classes = len(aug_models)

    augmented_data = tf.stack([aug_model(input) for aug_model in aug_models]) # Shape=(nb_classes, bs, 1)
    selector = tf.transpose(tf.one_hot(label_input, depth=nb_classes)) # Shape=(nb_classes, bs)
    augmentation = tf.reduce_sum(selector[..., None] * augmented_data, axis=0) # Shape=(bs, 1)
    print(augmentation)

    # prints:
    # tf.Tensor(
    # [[4.]
    # [3.]
    # [2.]], shape=(3, 1), dtype=float32)
    注意:您可能需要将这些操作包装到 Keras Lambda layer 中.

    关于python - 在训练期间将非张量参数传递给 Keras 模型/使用张量进行索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65829671/

    26 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com