gpt4 book ai didi

python - 如何在 TensorFlow 中实现 Numpy where 索引?

转载 作者:太空宇宙 更新时间:2023-11-03 13:58:57 25 4
gpt4 key购买 nike

我有以下使用 numpy.where 的操作:

    mat = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
index = np.array([[1,0,0],[0,1,0],[0,0,1]])
mat[np.where(index>0)] = 100
print(mat)

如何在 TensorFlow 中实现等价物?

mat = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
index = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
tf_mat = tf.constant(mat)
tf_index = tf.constant(index)
indi = tf.where(tf_index>0)
tf_mat[indi] = -1 <===== not allowed

最佳答案

假设您想要创建一个带有一些替换元素的新张量,而不是更新变量,您可以这样做:

import numpy as np
import tensorflow as tf

mat = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
index = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
tf_mat = tf.constant(mat)
tf_index = tf.constant(index)
tf_mat = tf.where(tf_index > 0, -tf.ones_like(tf_mat), tf_mat)
with tf.Session() as sess:
print(sess.run(tf_mat))

输出:

[[-1  2  3]
[ 4 -1 6]
[ 7 8 -1]]

关于python - 如何在 TensorFlow 中实现 Numpy where 索引?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51588899/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com