gpt4 book ai didi

keras - input_tensors在tf.keras.models.clone_model中的作用

转载 作者:行者123 更新时间:2023-12-03 20:58:20 25 4
gpt4 key购买 nike

我正在尝试复制现有的 keras 模型。以下是我创建的示例代码,它似乎按预期工作。

model = CreateSimpleModel()
model.compile(loss="sparse_categorical_crossentropy",
optimizer="adam",
metrics=["accuracy"])

model.summary()


model_cloned = tf.keras.models.clone_model(model)
model_cloned.set_weights(model.get_weights())

print(model(np.array([[1, 2]])))
print(model_cloned(np.array([[1, 2]])))

但是,如果我们查看有关 tf.keras.models.clone_model 的官方文档在接下来的页面中,有一个名为 input_tensors 的参数.

https://www.tensorflow.org/api_docs/python/tf/keras/models/clone_model

这个参数的作用我不是很确定。从上面的示例代码中,我不太明白为什么在某些情况下需要它。有人可以用一些例子来解释吗?

最佳答案

编辑:请不要做我在下面所做的事情。使用 GradCAM 查看第一个卷积层中使用的权重后,似乎 input_tensors 参数没有任何影响,尽管输入发生了变化,但 base_model 的所有克隆都具有相同的权重。
就我而言,我使用 tf.keras.models.clone_model 来克隆基本的预训练神经网络,以便我的所有多个输入都有自己的路径:

# inputs:
x1 = tf.keras.layers.Input(shape=(None, None, 3), name="x1")
x2 = tf.keras.layers.Input(shape=(None, None, 3), name="x2")
x3 = tf.keras.layers.Input(shape=(None, None, 3), name="x3")

# load base model:
base_model = tf.keras.applications.DenseNet169(input_tensor=x1, input_shape=(224, 224, 3), include_top=False, pooling='avg')

# create copies:
base_model2 = tf.keras.models.clone_model(base_model, input_tensors=x2)
base_model3 = tf.keras.models.clone_model(base_model, input_tensors=x3)

# you have to rename the layers in each model so there aren't any conflicts:
cnt = 0
for mod in [base_model1, base_model2, base_model3, base_model4, base_model5, base_model6]:
cnt += 1
for layer in mod.layers:
old_name = layer.name
layer._name = f"base_model{cnt}_{old_name}"

# this bit isn't necessary unless you want to access weights easily later on:
base1_out = base_model.output
base2_out = base_model2.output
base3_out = base_model3.output

# concatenate the outputs:
concatenated = tf.keras.layers.concatenate([base1_out, base2_out, base3_out], axis=-1)

# add dense layers if you want:
concat_dense = tf.keras.layers.Dense(2048)(concatenated)
out = tf.keras.layers.Dense(class_count, activation='softmax')(concat_dense)

tf.keras.models.Model(inputs=[x1, x2, x3], outputs=[out])
请注意,我的输入(以字典的形式)来自使用 TensorFlow 的 tf.data.Dataset 创建的符号张量。

关于keras - input_tensors在tf.keras.models.clone_model中的作用,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59705455/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com