gpt4 book ai didi

python - 制作自定义(tf.function)函数二维数组

转载 作者:行者123 更新时间:2023-12-04 08:19:24 25 4
gpt4 key购买 nike

我正在尝试将变量 x,y 的函数的二维数组计算为 tf.function。该函数相当复杂,我想制作这个函数的二维数组,其中 x 和 y 采用值列表(tf.linspace)。我试过输入这样一个函数的相关参数,这就是它的样子

@tf.function
def function_matrix(xi, xf, yi, yf, num , some_other_args):

#part1
M=np.zeros((num, num))
xlist=tf.linspace(xi, xf, num)
ylist=tf.linspace(yi, yf, num)

#part2
for x in range(num):
for y in range(num):
M[x,y]=some_complicated_function(xlist[x], ylist[y], some_other_args) #this is also a @tf.function

return (M)
我遇到的问题是在 tf.function 中,如果我尝试访问像 xlist[x] 这样的数组元素,结果是 Tensor("strided_slice:0", shape=(), dtype=float64) .所以当在 some_complicated_function 中传递这个值时,我得到一个错误“设置一个带有序列的数组元素”。如果 function_matrix 不是 tf.function,则不会发生此类错误。有人可以帮忙吗?至于我哪里可能出错?或者我可以计算一个相当复杂函数的二维矩阵的任何替代方法?
任何帮助将不胜感激,谢谢!
我试过的 :
第 1 部分运行良好,如果我将 xlist 作为函数的输出返回,我会得到一个普通数组 tf.Tensor( [the_array_here], shape=(num,), dtype=float64) .同样,如果输出是 xlist[index],我得到 tf.Tensor( [the_element_here], shape=(), dtype=float64) .但是我是否尝试从函数中打印 xlist[index],我得到 Tensor("strided_slice:0", shape=(), dtype=float64) .所以我得出的结论是,tf 以某种方式将 xlist[index] 视为某种占位符。但我不知道为什么...

最佳答案

哦,好问题! tensorflow真的不喜欢 for循环,它是 python无法自动转换为 tensorflow 的代码graph representation .实现这一点的方法是生成要在张量中操作的网格。比方说:

xlist=[1,2] # this is a tf.Tensor
ylist=[1,2] # this is a tf.Tensor
然后,使用 tf.meshgrid , 你应该构造 xylist :
xylist=[[1,1], [1,2], [2,1], [2,2]] # this is a tf.Tensor
然后使用 tf.map_fn将您的功能应用于每一对。
M = tf.map_fn(xylist, some_complicated_function)
M = tf.reshape(M, (...))
请注意,如果 some_complicated_function包含任何非 tensorflow代码(或无法自动转换的代码),例如使用 numpy , pandas , pillow ...,您可以将其包装在 tf.py_function 中- 但现在这种方式违背了将函数转换为 tf.function 的目的. (编辑:我现在看到您说: # this is also a tf.function ,这意味着您不必将其包装在 tf.py_function 中)
您还可以包含 extra_args通过附加到 xylist 中的每一对(是的,每一对,即使它们是恒定的)。
TL;DR:使用 tf.map_fn而不是嵌套 for循环。

关于python - 制作自定义(tf.function)函数二维数组,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65561252/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com