gpt4 book ai didi

python - tf.Variable 具有来自输入占位符的动态形状

转载 作者:太空宇宙 更新时间:2023-11-03 15:42:54 25 4
gpt4 key购买 nike

我正在尝试构建一个网络,我的要求是隐藏单元形状应根据输入形状,以便用户可以提供任意长度的输入。我想做的是:

import tensorflow as tf

n_hidden_1=100
input_ = tf.placeholder(name='input_data',shape=[None,None],dtype=tf.float32)
variable_data = tf.Variable(tf.random_normal([tf.shape(input_)[1], n_hidden_1]))

bias_ = tf.Variable(tf.random_normal([n_hidden_1]))

output = tf.matmul(input_,variable_data)+bias_

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(output,feed_dict={input_:[[1,2,3,4,5]]}))

我要

variable_data = tf.Variable(tf.random_normal([tf.shape(input_)[1], n_hidden_​​1]))

来自输入占位符的形状,但它给出了一个错误。这就是我尝试这种方法的原因:

variable_data = tf.Variable(tf.random_normal([tf.shape(input_)[1], n_hidden_1]),validate_shape=False)

但是我又遇到了另一个错误:

InvalidArgumentError: You must feed a value for placeholder tensor 'input_data_7' with dtype float and shape [?,?]
[[Node: input_data_7 = Placeholder[dtype=DT_FLOAT, shape=[?,?], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

如何从变量形状的占位符中获取形状整数。

最佳答案

tf.Variable 不能真正具有动态形状,因为它对他们的目的没有意义。在 tensorflow 中,变量意味着优化的模型参数,通常是通过 SGD。在优化期间,您通常假设您的成本函数及其定义的空间在每一步都不会发生变化。不这样做——尤其是以随机的、非预定的方式——会导致定义不明确的优化问题。

但是,您可以做的是拥有一个变量,当第一个样本出现时,我们在运行时推迟其(静态)形状的定义。您就快完成了:您只需要将示例提供给 global_variables_initializer,它将变量初始化为正确的形状和值:

import tensorflow as tf


n_hidden_1=100
input_ = tf.placeholder(name='input_data',shape=[None,None],dtype=tf.float32)
variable_data = tf.Variable(tf.random_normal([tf.shape(input_)[1], n_hidden_1]), validate_shape=False)

bias_ = tf.Variable(tf.random_normal([n_hidden_1]))

output = tf.matmul(input_,variable_data)+bias_

with tf.Session() as sess:
sess.run(tf.global_variables_initializer(), <b>feed_dict={input_:[[1,2,3,4,5]]}</b>)
print(sess.run(output, feed_dict={input_:[[1,2,3,4,5]]}))


关于python - tf.Variable 具有来自输入占位符的动态形状,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51552199/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com