gpt4 book ai didi

tensorflow - 根据条件修改一维张量

转载 作者:行者123 更新时间:2023-12-05 05:21:43 26 4
gpt4 key购买 nike

`我想对 tensorflow 中一维张量的每个元素应用一个条件,从而修改输入张量。例如,如果张量是:

y_true = tf.Variable([0.0, 0.3, 0.0, 0.4, 0.0, 0.1, 0.0, 0.0]),

我想检查每个元素是否大于0.1。如果是,则该元素变为 1,否则为 0。如何在 tensorflow 中完成同样的操作?

到目前为止,我一直在尝试编写一个 python 函数,然后使用 py_func 在 tensorflow 中使用它,但它不起作用。看下面的代码-

y_true = tf.Variable([0.0, 0.3, 0.0, 0.4, 0.0, 0.1, 0.0, 0.0])
with tf.Session() as sess:
y = tf.py_func(round_with_threshold, [y_true], tf.float32)
y.eval()

def round_with_threshold(arr):
threshold = 0.1
rounded_arr = np.zeros(arr.shape[0])
for i in range(arr.shape[0]):
if arr[i]>=threshold:
rounded_arr[i] = 1
else:
rounded_arr[i] = 0
return rounded_arr

是否可以在 tensorflow 中执行此操作而无需编写任何 python 函数?

最佳答案

您可以执行以下操作:

import tensorflow as tf

y_true = tf.Variable([0.0, 0.3, 0.0, 0.4, 0.0, 0.1, 0.0, 0.0])
comp_op = tf.greater(y_true, 0.1) # returns boolean tensor
cast_op = tf.cast(comp_op, tf.int32) # casts boolean tensor into int32

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(cast_op))

打印:[0 1 0 1 0 0 0 0]

关于tensorflow - 根据条件修改一维张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42437169/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com