gpt4 book ai didi

python - 将单位矩阵连接到每个向量

转载 作者:太空宇宙 更新时间:2023-11-04 04:28:57 26 4
gpt4 key购买 nike

我想通过向输入向量添加几个不同的后缀来修改我的输入。例如,如果(单个)输入是 [1, 5, 9, 3] 我想像这样创建三个向量(存储为矩阵):

[[1, 5, 9, 3, 1, 0, 0],
[1, 5, 9, 3, 0, 1, 0],
[1, 5, 9, 3, 0, 0, 1]]

当然,这只是一个观察结果,因此在本例中模型的输入是 (None, 4)。简单的方法是在其他地方准备输入数据(最有可能是 numpy)并相应地调整输入的形状。我可以做到,但我更愿意在 TensorFlow/Keras 中进行。

我已将问题隔离到这段代码中:

import keras.backend as K
from keras import Input, Model
from keras.layers import Lambda


def build_model(dim_input: int, dim_eye: int):
input = Input((dim_input,))
concat = Lambda(lambda x: concat_eye(x, dim_input, dim_eye))(input)
return Model(inputs=[input], outputs=[concat])


def concat_eye(x, dim_input, dim_eye):
x = K.reshape(x, (-1, 1, dim_input))
x = K.repeat_elements(x, dim_eye, axis=1)
eye = K.expand_dims(K.eye(dim_eye), axis=0)
eye = K.tile(eye, (-1, 1, 1))
out = K.concatenate([x, eye], axis=2)
return out


def main():
import numpy as np

n = 100
dim_input = 20
dim_eye = 3

model = build_model(dim_input, dim_eye)
model.compile(optimizer='sgd', loss='mean_squared_error')

x_train = np.zeros((n, dim_input))
y_train = np.zeros((n, dim_eye, dim_eye + dim_input))
model.fit(x_train, y_train)


if __name__ == '__main__':
main()

问题似乎出在 tile 函数的 shape 参数中的 -1 中。我尝试将其替换为 1None。每个都有自己的错误:

  • -1:model.fit

    期间出错
    tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected multiples[0] >= 0, but got -1
  • 1:执行 model.fit

    时出错
    tensorflow.python.framework.errors_impl.InvalidArgumentError: ConcatOp : Dimensions of inputs should match: shape[0] = [32,3,20] vs. shape[1] = [1,3,3]
  • :build_model期间出错:

    Failed to convert object of type <class 'tuple'> to Tensor. Contents: (None, 1, 1). Consider casting elements to a supported type.

最佳答案

您需要使用 K.shape()而不是获取输入张量的符号形状。这是因为批量大小为 None,因此传递 K.int_shape(x)[0]None-1 作为 K.tile() 的第二个参数的一部分将不起作用:

eye = K.tile(eye, (K.shape(x)[0], 1, 1))

关于python - 将单位矩阵连接到每个向量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53022148/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com