gpt4 book ai didi

python - TensorFlow:如何通过复制张量之一来连接张量?

转载 作者:太空宇宙 更新时间:2023-11-03 15:51:04 26 4
gpt4 key购买 nike

我想通过复制其中一个张量来连接两个张量。例如,我有两个形状为 [2, 2, 3] 和 [2, 3] 的张量。结果应为 [2, 2, 6] 的形状。

t1 = [[[ 1, 1, 1], [2, 2, 2]],
[[ 3, 3, 3], [4, 4, 4]]]
t2 = [[ 5, 5, 5], [6, 6, 6]]
"""
t3 = # some tf ops
t3 should be
t3 = [[[ 1, 1, 1, 5, 5, 5], [2, 2, 2, 5, 5, 5]],
[[ 3, 3, 3, 6, 6, 6], [4, 4, 4, 6, 6, 6]]]
"""

因此,如果两个张量的形状为 [10, 5, 8] 和 [10, 3],则结果的形状应为 [10, 5, 11]。

已更新

另一个例子:

t1 = np.reshape(np.arange(3*4*5), [3,4,5])
t2 = np.reshape(np.arange(3*1*2), [3,2])
""""
t3 should be
[[[ 0., 1., 2., 3., 4., 0., 1.],
[ 5., 6., 7., 8., 9., 0., 1.],
[ 10., 11., 12., 13., 14., 0., 1.],
[ 15., 16., 17., 18., 19., 0., 1.]],

[[ 20., 21., 22., 23., 24., 2., 3.],
[ 25., 26., 27., 28., 29., 2., 3.],
[ 30., 31., 32., 33., 34., 2., 3.],
[ 35., 36., 37., 38., 39., 2., 3.]],

[[ 40., 41., 42., 43., 44., 4., 5.],
[ 45., 46., 47., 48., 49., 4., 5.],
[ 50., 51., 52., 53., 54., 4., 5.],
[ 55., 56., 57., 58., 59., 4., 5.]]]
"""

最佳答案

函数tf.tile可以帮助你做到这一点。点击here获取函数的详细信息。

import numpy as np
import tensorflow as tf

t1 = np.reshape(np.arange(3*4*5), [3,4,5])
t2 = np.reshape(np.arange(3*1*2), [3,2])

# Keep t1 stay
t1_p = tf.placeholder(tf.float32, [3,4,5])

# Change t2 from shape(3,2) to shape(3,4,2) followed below two steps:
# 1. copy element of 2rd dimension of t2 as many times as you hope, as the updated example, it is 4
# 2. reshape the tiled tensor to compatible shape
t2_p = tf.placeholder(tf.float32, [3,2])
# copy the element of 2rd dimention of t2 by 4 times
t2_p_tiled = tf.tile(t2_p, [1, 4])
# reshape tiled t2 with shape(3,8) to the compatible shape(3,4,2)
t2_p_reshaped = tf.reshape(t2_p_tiled, [3,4,2])

# Concat t1 and changed t2, then you will get t3 you want
t3_p = tf.concat([t1_p, t2_p_reshaped], 2)

sess = tf.InteractiveSession()
t3 = sess.run(t3_p, {t1_p:t1, t2_p:t2})

print '*' * 20
print t1
print '*' * 20
print t2
print '*' * 20
print t3

# if you confused what the tf.tile did, you can print t2_p_tiled to see what happend
t2_tile = sess.run(t2_p_tiled, {t2_p:t2})
print '*' * 20
print t2_tile

关于python - TensorFlow:如何通过复制张量之一来连接张量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41295168/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com