gpt4 book ai didi

python - 类型错误 : Value passed to parameter 'indices' has DataType float32 not in list of allowed values: int32, int64

转载 作者:行者123 更新时间:2023-12-01 07:47:57 30 4
gpt4 key购买 nike

我使用 Keras 构建模型,模型中有两个输入,其数据类型为“int32”。然后我使用 keras Lamba 层通过 K.gather(reference,indexs) 在嵌入矩阵中查找。我看到索引应该是 int 张量,我认为我的代码满足这一点,我不知道为什么会出现错误。我真的需要帮助!

    input_A = Input(batch_shape=(128,1),name='A_input',dtype='int32')
input_B = Input(batch_shape=(128,1),name='B_input',dtype='int32')

input_A_ = Lambda(lambda x:K.reshape(x,(-1,)))(input_A)
input_B_ = Lambda(lambda x:K.reshape(x, (-1,)))(input_B)

input_A__ = Lambda(lambda x:K.cast(x,dtype='int32'))(input_A_)
input_B__ = Lambda(lambda x:K.cast(x,dtype='int32'))(input_B_)

embedded_text_A = Lambda(lambda x:K.gather(M1,x))(input_A__)
embedded_text_B = Lambda(lambda x:K.gather(M1,x))(input_B__)

最佳答案

出于某种神秘的原因,如果将 K.cast() 放入 lambda 中,它将正常工作:

input_A = Input(batch_shape=(128,1), name='A_input', dtype='int32')
input_B = Input(batch_shape=(128,1), name='B_input', dtype='int32')

input_A_ = Lambda(lambda x: K.reshape(x, (-1,)))(input_A)
input_B_ = Lambda(lambda x: K.reshape(x, (-1,)))(input_B)

embedded_text_A = Lambda(lambda x: K.gather(M1, K.cast(x, dtype='int32')))(input_A_)
embedded_text_B = Lambda(lambda x: K.gather(M1, K.cast(x, dtype='int32')))(input_B_)

因此,Lambda 层会在内部进行一些奇怪的数据类型转换。

我认为这是某种错误,我的假设是隐式转换发生在 Lambda__call__ (which is inherited from Layer.__call__) 内部。 。我无法跟踪它,但我猜“隐式转换”错误位于 Layer.__call__ 中的某个位置,但在 451 行之前,其中实际调用了 Lambda.call。

关于python - 类型错误 : Value passed to parameter 'indices' has DataType float32 not in list of allowed values: int32, int64,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56357004/

30 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com