gpt4 book ai didi

python - 输入张量 进入循环,形状为 (),但在一次迭代后形状为

转载 作者:行者123 更新时间:2023-12-05 06:10:49 25 4
gpt4 key购买 nike

我正在尝试使用 tf.function 保存模型在贪婪解码方法上。

代码已经过测试,并且按预期在急切模式(调试)下工作。但是,它在非急切执行中不起作用

该方法得到一个 namedtuple称为 Hyp看起来像这样:

Hyp = namedtuple(
'Hyp',
field_names='score, yseq, encoder_state, decoder_state, decoder_output'
)

while 循环是这样调用的:

_, hyp = tf.while_loop(
cond=condition_,
body=body_,
loop_vars=(tf.constant(0, dtype=tf.int32), hyp),
shape_invariants=(
tf.TensorShape([]),
tf.nest.map_structure(get_shape_invariants, hyp),
)
)

这是 body_ 的相关部分:

def body_(i_, hypothesis_: Hyp):

# [:] Collapsed some code ..

def update_from_next_id_():
return Hyp(
# Update values ..
)

# The only place where I generate a new hypothesis_ namedtuple
hypothesis_ = tf.cond(
tf.not_equal(next_id, blank),
true_fn=lambda: update_from_next_id_(),
false_fn=lambda: hypothesis_
)

return i_ + 1, hypothesis_

我得到的是 ValueError :

ValueError: Input tensor 'hypotheses:0' enters the loop with shape (), but has shape <unknown> after one iteration. To allow the shape to vary across iterations, use the shape_invariants argument of tf.while_loop to specify a less-specific shape.

这可能是什么问题?

下面是如何input_signaturetf.function 定义我想连载。

在这里,self.greedy_decode_impl是实际的实现 - 我知道这有点难看但是 self.greedy_decode这就是我所说的。

self.greedy_decode = tf.function(
self.greedy_decode_impl,
input_signature=(
tf.TensorSpec([1, None, self.config.encoder.lstm_units], dtype=tf.float32),
Hyp(
score=tf.TensorSpec([], dtype=tf.float32),
yseq=tf.TensorSpec([1, None], dtype=tf.int32),
encoder_state=tuple(
(tf.TensorSpec([1, lstm.units], dtype=tf.float32),
tf.TensorSpec([1, lstm.units], dtype=tf.float32))
for (lstm, _) in self.encoder_network.lstm_stack
),
decoder_state=tuple(
(tf.TensorSpec([1, lstm.units], dtype=tf.float32),
tf.TensorSpec([1, lstm.units], dtype=tf.float32))
for (lstm, _) in self.predict_network.lstm_stack
),
decoder_output=tf.TensorSpec([1, None, self.config.decoder.lstm_units], dtype=tf.float32)
),
)
)

执行greedy_decode_impl :

def greedy_decode_impl(self, encoder_outputs: tf.Tensor, hypotheses: Hyp, blank=0) -> Hyp:

hyp = hypotheses

encoder_outputs = encoder_outputs[0]

def condition_(i_, *_):
time_steps = tf.shape(encoder_outputs)[0]
return tf.less(i_, time_steps)

def body_(i_, hypothesis_: Hyp):

encoder_output_ = tf.reshape(encoder_outputs[i_], shape=(1, 1, -1))

join_out = self.join_network((encoder_output_, hypothesis_.decoder_output), training=False)

logits = tf.squeeze(tf.nn.log_softmax(tf.squeeze(join_out)))
next_id = tf.argmax(logits, output_type=tf.int32)
log_prob = logits[next_id]
next_id = tf.reshape(next_id, (1, 1))

def update_from_next_id_():
decoder_output_, decoder_state_ = self.predict_network(
next_id,
memory_states=hypothesis_.decoder_state,
training=False
)
return Hyp(
score=hypothesis_.score + log_prob,
yseq=tf.concat([hypothesis_.yseq, next_id], axis=0),
decoder_state=decoder_state_,
decoder_output=decoder_output_,
encoder_state=hypothesis_.encoder_state
)

hypothesis_ = tf.cond(
tf.not_equal(next_id, blank),
true_fn=lambda: update_from_next_id_(),
false_fn=lambda: hypothesis_
)

return i_ + 1, hypothesis_

_, hyp = tf.while_loop(
cond=condition_,
body=body_,
loop_vars=(tf.constant(0, dtype=tf.int32), hyp),
shape_invariants=(
tf.TensorShape([]),
tf.nest.map_structure(get_shape_invariants, hyp),
)
)

return hyp

为什么它在 eager-mode 下工作而不在 non-eager 模式下工作?

根据 tf.while_loop 的文档一个namedtuple应该没问题。


斐波那契例子

为了检查这是否适用于 namedtuple ,我已经使用类似的机制实现了斐波那契数列。为了包含条件,循环在到达步骤 n // 2 时停止附加新数字。 :

正如我们在下面看到的,该方法应该可以在没有 Python 副作用的情况下工作。

from collections import namedtuple

import tensorflow as tf

FibonacciStep = namedtuple('FibonacciStep', field_names='seq, prev_value')


def shape_list(x):
static = x.shape.as_list()
dynamic = tf.shape(x)
return [dynamic[i] if s is None else s for i, s in enumerate(static)]


def get_shape_invariants(tensor):
shapes = shape_list(tensor)
return tf.TensorShape([i if isinstance(i, int) else None for i in shapes])


def save_tflite(fp, concrete_fn):
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_fn])
converter.experimental_new_converter = True
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
converter.optimizations = []
tflite_model = converter.convert()
with tf.io.gfile.GFile(fp, 'wb') as f:
f.write(tflite_model)


@tf.function(
input_signature=(
tf.TensorSpec([], dtype=tf.int32),
FibonacciStep(
seq=tf.TensorSpec([1, None], dtype=tf.int32),
prev_value=tf.TensorSpec([], dtype=tf.int32),
)
)
)
def fibonacci(n: tf.Tensor, fibo: FibonacciStep):

def cond_(i_, *args):
return tf.less(i_, n)

def body_(i_, fibo_: FibonacciStep):

prev_value = fibo_.seq[0, -1] + fibo_.prev_value

def append_value():
return FibonacciStep(
seq=tf.concat([fibo_.seq, tf.reshape(prev_value, shape=(1, 1))], axis=-1),
prev_value=fibo_.seq[0, -1]
)

fibo_ = tf.cond(
tf.less_equal(i_, n // 2),
true_fn=lambda: append_value(),
false_fn=lambda: fibo_
)

return i_ + 1, fibo_

_, fibo = tf.while_loop(
cond=cond_,
body=body_,
loop_vars=(0, fibo),
shape_invariants=(
tf.TensorShape([]),
tf.nest.map_structure(get_shape_invariants, fibo),
)
)

return fibo


def main():

n = tf.constant(10, dtype=tf.int32)
fibo = FibonacciStep(
seq=tf.constant([[0, 1]], dtype=tf.int32),
prev_value=tf.constant(0, dtype=tf.int32),
)

fibo = fibonacci(n, fibo=fibo)
fibo = fibonacci(n + 10, fibo=fibo)

fp = '/tmp/fibonacci.tflite'
concrete_fn = fibonacci.get_concrete_function()
save_tflite(fp, concrete_fn)

print(fibo.seq.numpy()[0].tolist())

print('All done.')


if __name__ == '__main__':
main()

输出:

[0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597, 2584]
All done.

最佳答案

好吧,原来是这样

tf.concat([hypothesis_.yseq, next_id], axis=0),

应该是

tf.concat([hypothesis_.yseq, next_id], axis=-1),

公平地说,错误消息有点给了您在哪里查看的提示,但用“有帮助”来形容它太过分了。我通过连接错误的轴违反了 TensorSpec,仅此而已,但 Tensorflow 无法直接指向受影响的 Tensor(目前)。

关于python - 输入张量 <name> 进入循环,形状为 (),但在一次迭代后形状为 <unknown>,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64244431/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com