gpt4 book ai didi

tensorflow - TensorFlow `py_func` 的输出具有未知的等级/形状

转载 作者:行者123 更新时间:2023-12-02 22:30:32 28 4
gpt4 key购买 nike

我正在尝试在 TensorFlow 中创建一个简单的神经网络。唯一棘手的部分是我有一个使用 py_func 实现的自定义操作。 。当我传递 py_func 的输出时到 Dense层,TensorFlow 提示排名应该是已知的。具体错误为:

ValueError: Inputs to `Dense` should have known rank.

当我传递数据时,我不知道如何保留数据的形状 py_func 。我的问题是如何获得正确的形状?下面我举一个简单的例子来说明这个问题。

def my_func(x):
return np.sinh(x).astype('float32')

inp = tf.convert_to_tensor(np.arange(5))
y = tf.py_func(my_func, [inp], tf.float32, False)

with tf.Session() as sess:
with sess.as_default():
print(inp.shape)
print(inp.eval())
print(y.shape)
print(y.eval())

此代码段的输出是:

(5,)
[0 1 2 3 4]
<unknown>
[ 0.
1.17520118 3.62686038 10.01787472 27.28991699]

为什么是y.shape <unknown> ?我想要的形状是 (5,)inp 相同。谢谢!

最佳答案

由于py_func可以执行任意Python代码并输出任何内容,TensorFlow无法计算出形状(需要分析函数体的Python代码)您可以手动给出形状

y.set_shape(inp.get_shape())

关于tensorflow - TensorFlow `py_func` 的输出具有未知的等级/形状,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42590431/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com