gpt4 book ai didi

python - Tensorflow:堆叠张量中的所有行对

转载 作者:太空狗 更新时间:2023-10-29 21:06:27 24 4
gpt4 key购买 nike

给定张量 t=[[1,2], [3,4]],我需要生成 ts=[[1,2,1,2], [ 1,2,3,4], [3,4,1,2], [3,4,3,4]]。也就是说,我需要将所有行对堆叠在一起。重要提示:张量的维度为 [None, 2],即。第一维是可变的。

我试过:

  • 使用 tf.while_loop 生成索引列表 idx=[[0, 0], [0, 1], [1, 0], [1, 1] ],然后是 tf.gather(ts, idx)。这行得通,但很乱,我不知道如何处理渐变。
  • 2 个 for 循环迭代 tf.unstack(t),将堆叠的行添加到缓冲区,然后 tf.stack(buffer)。如果第一个维度是可变的,这将不起作用。
  • 在广播中寻找灵感。例如,给定 x=t.expand_dims(t, 0), y=t.expand_dims(t, 1), s=tf.reshape(tf.add(x, y), [-1, 2] ) s 将是 [[2, 4], [4, 6], [4, 6], [6, 8]],即。每行组合的总和。但是我怎样才能做堆叠而不是求和呢?我已经失败了 2 天 :)

最佳答案

使用 tf.meshgrid() 和一些 reshape 的解决方案:

import tensorflow as tf
import numpy as np

t = tf.placeholder(tf.int32, [None, 2])
num_rows, size_row = tf.shape(t)[0], tf.shape(t)[1] # actual dynamic dimensions

# Getting pair indices using tf.meshgrid:
idx_range = tf.range(num_rows)
pair_indices = tf.stack(tf.meshgrid(*[idx_range, idx_range]))
pair_indices = tf.transpose(pair_indices, perm=[1, 2, 0])

# Finally gathering the rows accordingly:
res = tf.reshape(tf.gather(t, pair_indices), (-1, size_row * 2))

with tf.Session() as sess:
print(sess.run(res, feed_dict={t: np.array([[1,2], [3,4], [5,6]])}))
# [[1 2 1 2]
# [3 4 1 2]
# [5 6 1 2]
# [1 2 3 4]
# [3 4 3 4]
# [5 6 3 4]
# [1 2 5 6]
# [3 4 5 6]
# [5 6 5 6]]

使用笛卡尔积的解决方案:

import tensorflow as tf
import numpy as np

t = tf.placeholder(tf.int32, [None, 2])
num_rows, size_row = tf.shape(t)[0], tf.shape(t)[1] # actual dynamic dimensions

# Getting pair indices by computing the indices cartesian product:
row_idx = tf.range(num_rows)
row_idx_a = tf.expand_dims(tf.tile(tf.expand_dims(row_idx, 1), [1, num_rows]), 2)
row_idx_b = tf.expand_dims(tf.tile(tf.expand_dims(row_idx, 0), [num_rows, 1]), 2)
pair_indices = tf.concat([row_idx_a, row_idx_b], axis=2)

# Finally gathering the rows accordingly:
res = tf.reshape(tf.gather(t, pair_indices), (-1, size_row * 2))

with tf.Session() as sess:
print(sess.run(res, feed_dict={t: np.array([[1,2], [3,4], [5,6]])}))
# [[1 2 1 2]
# [1 2 3 4]
# [1 2 5 6]
# [3 4 1 2]
# [3 4 3 4]
# [3 4 5 6]
# [5 6 1 2]
# [5 6 3 4]
# [5 6 5 6]]

关于python - Tensorflow:堆叠张量中的所有行对,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50536721/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com