gpt4 book ai didi

python - 无法写入 TensorArray 索引,因为值形状与 TensorArray 的推断元素形状不兼容

转载 作者:行者123 更新时间:2023-12-01 08:28:35 31 4
gpt4 key购买 nike

我有一个字符串张量(名为句子),我想在其中获取其单词的嵌入:

sentence = tf.map_fn(lambda x: tf.string_split([x], delimiter=' ').values, sentence, dtype=tf.string)

我使用上面的代码对批处理中的所有句子应用字符串拆分。然后,我在单词表中应用查找来获取这些张量内每个单词的单词索引:

sentence = tf.map_fn(lambda x: tf.cast(word_table.lookup(x), tf.int32), sentence, dtype=tf.int32)

批量大小1运行时,我运行代码没有任何问题。但是,当批量大小大于 1 时,我总是会收到以下错误,该错误指向上面的第一个代码片段。

InvalidArgumentError (see above for traceback): TensorArray sentence_splitter/map/TensorArray_1_1: Could not write to TensorArray index 10 because the value shape is [4] which is incompatible with the TensorArray's inferred element shape: [6] (consider setting infer_shape=False).

我不明白 Tensorflow 试图通过此错误表达什么!如果有人能解释这个错误,那就太好了。谢谢!

最佳答案

当你的batch size大于1时,在这段代码之后

sentence = tf.map_fn(lambda x: tf.string_split([x], delimiter=' ').values, sentence, dtype=tf.string)

tf.string_split()函数作用于不同的句子,产生不同数量的分割结果。各个维度的不兼容导致最终结果无法存储到张量中,从而出现错误。这清楚吗?

关于python - 无法写入 TensorArray 索引,因为值形状与 TensorArray 的推断元素形状不兼容,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54052242/

31 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com