gpt4 book ai didi

python - Tensorflow - 具有批处理数据的输入矩阵的 matmul

转载 作者:IT老高 更新时间:2023-10-28 21:46:02 27 4
gpt4 key购买 nike

我有一些由 input_x 表示的数据。它是一个未知大小的张量(应该批量输入),每个项目的大小为 ninput_x 经历 tf.nn.embedding_lookup,因此 embed 现在具有维度 [?, n, m] 其中m 是嵌入大小,? 是指未知的批量大小。

这里有描述:

input_x = tf.placeholder(tf.int32, [None, n], name="input_x") 
embed = tf.nn.embedding_lookup(W, input_x)

我现在正尝试将输入数据(现在通过嵌入维度扩展)中的每个样本乘以矩阵变量 U,但我似乎不知道该怎么做那个。

我第一次尝试使用 tf.matmul 但由于形状不匹配而出现错误。然后我通过扩展 U 的维度并应用 batch_matmul 尝试了以下操作(我还尝试了 tf.nn.math_ops. 中的函数,结果是一样的):

U = tf.Variable( ... )    
U1 = tf.expand_dims(U,0)
h=tf.batch_matmul(embed, U1)

这通过了初始编译,但是当应用实际数据时,我收到以下错误:

In[0].dim(0) 和 In[1].dim(0) 必须相同:[64,58,128] vs [1,128,128]

我也知道为什么会这样 - 我复制了 U 的维度,现在是 1,但小批量大小为 64 ,不适合。

我怎样才能正确地对我的张量矩阵输入进行矩阵乘法(对于未知的批量大小)?

最佳答案

以前的答案已过时。目前tf.matmul()支持 rank > 2 的张量:

The inputs must be matrices (or tensors of rank > 2, representing batches of matrices), with matching inner dimensions, possibly after transposition.

还删除了 tf.batch_matmul() 并且 tf.matmul() 是进行批量乘法的正确方法。主要思想可以从以下代码中理解:

import tensorflow as tf
batch_size, n, m, k = 10, 3, 5, 2
A = tf.Variable(tf.random_normal(shape=(batch_size, n, m)))
B = tf.Variable(tf.random_normal(shape=(batch_size, m, k)))
tf.matmul(A, B)

现在您将收到形状为 (batch_size, n, k) 的张量。这就是这里发生的事情。假设您有 batch_size 矩阵 nxmbatch_size 矩阵 mxk。现在为每一对计算 nxm X mxk 给你一个 nxk 矩阵。您将拥有 batch_size 个。

请注意,这样的内容也是有效的:

A = tf.Variable(tf.random_normal(shape=(a, b, n, m)))
B = tf.Variable(tf.random_normal(shape=(a, b, m, k)))
tf.matmul(A, B)

会给你一个形状(a, b, n, k)

关于python - Tensorflow - 具有批处理数据的输入矩阵的 matmul,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38235555/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com