gpt4 book ai didi

python - 循环张量并将函数应用于每个元素

转载 作者:行者123 更新时间:2023-11-28 18:57:19 24 4
gpt4 key购买 nike

我想遍历包含 Int 列表的张量,并对每个元素应用一个函数。在函数中,每个元素都将从 python 的字典中获取值。我已经尝试使用tf.map_fn 的简单方法,它将在add 函数上工作,例如以下代码:

import tensorflow as tf

def trans_1(x):
return x+10

a = tf.constant([1, 2, 3])
b = tf.map_fn(trans_1, a)
with tf.Session() as sess:
res = sess.run(b)
print(str(res))
# output: [11 12 13]

但是下面的代码抛出KeyError: tf.Tensor'map_8/while/TensorArrayReadV3:0' shape=() dtype=int32异常:

import tensorflow as tf

kv_dict = {1:11, 2:12, 3:13}

def trans_2(x):
return kv_dict[x]

a = tf.constant([1, 2, 3])
b = tf.map_fn(trans_2, a)
with tf.Session() as sess:
res = sess.run(b)
print(str(res))

我的 tensorflow 版本是 1.13.1。先谢谢了。

最佳答案

有一种简单的方法可以实现您正在尝试的目标。

问题是传递给 map_fn 的函数必须有张量作为参数和张量作为返回值。但是,您的函数 trans_2 将纯 python int 作为参数并返回另一个 python int。这就是为什么您的代码不起作用的原因。

但是,TensorFlow 提供了一种简单的方法来包装普通的 python 函数,即 tf.py_func,您可以在您的案例中使用它,如下所示:

import tensorflow as tf

kv_dict = {1:11, 2:12, 3:13}

def trans_2(x):
return kv_dict[x]

def wrapper(x):
return tf.cast(tf.py_func(trans_2, [x], tf.int64), tf.int32)

a = tf.constant([1, 2, 3])
b = tf.map_fn(wrapper, a)
with tf.Session() as sess:
res = sess.run(b)
print(str(res))

你可以看到我添加了一个包装函数,它需要张量参数并返回一个张量,这就是它可以在 map_fn 中使用的原因。使用强制转换是因为 python 默认使用 64 位整数,而 TensorFlow 使用 32 位整数。

关于python - 循环张量并将函数应用于每个元素,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56804882/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com