gpt4 book ai didi

python - TensorFlow - 分割和挤压

转载 作者:行者123 更新时间:2023-12-02 09:24:23 25 4
gpt4 key购买 nike

我是 TensorFlow 新手,正在格式化一些数据以输入循环神经网络。我的数据由输入占位符 x 的 3d 张量给出。我想沿着第三个维度分割x,为此我有(注意n_timesteps对应于x沿着第三个维度的长度):

# Split the previous 3d tensor to get a list of 'n_timesteps' 2d tensors of
# shape (batch_size, features_dimension)
x = tf.split (x, n_timesteps, axis = 2)

尽管如此,正如我尝试使用numpy:

x = np.split (x, n_timesteps, axis = 2)

如果x是一个3d ndarray,那么np.split将返回一个n_timesteps数组列表,其中维度 3,因此第 3 维度是单维度。使用 numpy ,我知道我可以使用 np.squeeze 以及列表理解来轻松解决这个问题,以删除单例维度:

x = [np.squeeze(a, axis=2) for a in np.split(x, n_timesteps, axis=2)]

但是我怎样才能在 TF 上做同样的事情呢?

最佳答案

您可能正在寻找tf.unstack操作:

x = tf.unstack(x, axis=2)

关于python - TensorFlow - 分割和挤压,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42517452/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com