gpt4 book ai didi

python - TensorFlow 使用来自具有多个输出 : Cannot properly define shapes? 的生成器的数据集进行拟合

转载 作者:行者123 更新时间:2023-12-05 01:08:28 26 4
gpt4 key购买 nike

我正在尝试使用生成器将项目转换为具有多个输出的单个网络,但我似乎无法弄清楚如何在使用生成器时使多个输出正常工作。下面是一小段可验证的代码:

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models

def generate_sample():
x = list("123456789")
y = list("2345")
while 1:
yield np.array(x).astype(np.float32),[np.array(y).astype(np.float32),np.array(y).astype(np.float32)]

dataset = tf.data.Dataset.from_generator(generate_sample,
output_signature=(
tf.TensorSpec(shape=(9,), dtype=tf.float32),
tf.TensorSpec(shape=(2,4), dtype=tf.float32)

))

dataset = dataset.batch(batch_size=32)

inputs = keras.Input(shape=(next(generate_sample())[0].shape))
x = layers.Dense(512, activation = "relu")(inputs)
x_outputs = layers.Dense(4, activation="relu", name="output")(x)
y_outputs = layers.Dense(4, activation="relu", name="output2")(x)

model = keras.Model(inputs=inputs, outputs=[x_outputs,y_outputs])
model.compile(loss="mse", optimizer = "adam", metrics=['accuracy'])
history = model.fit(dataset, epochs=1, steps_per_epoch=10, validation_data=dataset, validation_steps=5)

这会导致一个很长的错误,其最后部分是:

InvalidArgumentError: Incompatible shapes: [32,2,4] vs. [32,4]
[[node mean_squared_error/SquaredDifference (defined at:1) ]][Op:__inference_train_function_8957]

Function call stack: train_function

我已经尝试过使用 output_shapeoutput_signature 等,以我能想象到的各种方式 reshape 数据。无论如何,我仍然会遇到形状问题。

我是否在这里遗漏了一些明显的东西,或者在 fit 中使用生成器作为数据集的源有什么问题?例如,当我从内存中加载数据时,这样做没有问题。

最佳答案

模型的输出不是一个形状为(2,4)的Tensor,而是两个形状为(4)的Tensor。

您应该更改您的生成器函数以反射(reflect)这一点:

def generate_sample():
x = list("123456789")
y = list("2345")
while 1:
yield np.array(x).astype(np.float32),(np.array(y).astype(np.float32),np.array(y).astype(np.float32))

以及您的输出签名:

dataset = tf.data.Dataset.from_generator(generate_sample,
output_signature=(
tf.TensorSpec(shape=(9,), dtype=tf.float32),
(tf.TensorSpec(shape=(4,), dtype=tf.float32),
tf.TensorSpec(shape=(4,), dtype=tf.float32)),
))

请注意,生成器的输出是嵌套元组。

关于python - TensorFlow 使用来自具有多个输出 : Cannot properly define shapes? 的生成器的数据集进行拟合,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65998640/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com