gpt4 book ai didi

python - 将标量输入 Tensorflow 2 模型的正确方法

转载 作者:行者123 更新时间:2023-12-04 17:20:23 24 4
gpt4 key购买 nike

在我的 Tensorflow 2 模型中,我希望批量大小是参数化的,这样我就可以动态构建具有适当批量大小的张量。我有以下代码:

batch_size_param = 128

tf_batch_size = tf.keras.Input(shape=(), name="tf_batch_size", dtype=tf.int32)
batch_indices = tf.range(0, tf_batch_size, 1)

md = tf.keras.Model(inputs={"tf_batch_size": tf_batch_size}, outputs=[batch_indices])
res = md(inputs={"tf_batch_size": batch_size_param})

代码在 tf.range 中抛出错误:

ValueError: Shape must be rank 0 but is rank 1
for 'limit' for '{{node Range}} = Range[Tidx=DT_INT32](Range/start, tf_batch_size, Range/delta)' with input shapes: [], [?], []

我认为问题在于 tf.keras.Input 会自动尝试在第一维扩展输入数组,因为它期望输入的部分形状没有批量大小和将根据输入数组的形状附加批量大小,在我的例子中是一个标量。我可以将标量值作为常量整数输入到 tf.range 中,但这一次,在编译模型图后我将无法更改它。

有趣的是,即使我也检查了文档,我也未能找到将标量输入 TF-2 模型的正确方法。那么,处理这种情况的最佳方法是什么?

最佳答案

不要使用 tf.keras.Input,只需通过子类化来定义模型。

import tensorflow as tf


class ScalarModel(tf.keras.Model):
def __init__(self):
super().__init__()

def call(self, x):
return tf.range(0, x, 1)


print(ScalarModel()(10))
# tf.Tensor([0 1 2 3 4 5 6 7 8 9], shape=(10,), dtype=int32)

关于python - 将标量输入 Tensorflow 2 模型的正确方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66546321/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com