gpt4 book ai didi

python - 当两个函数相互调用时,tf.function 是如何工作的

转载 作者:行者123 更新时间:2023-11-28 18:56:33 30 4
gpt4 key购买 nike

我使用 tensorflow==1.14 和 tf.enable_eager_execution() 构建我的模型,如下所示:

class Model:
def __init__(self):
self.embedding = tf.keras.layers.Embedding(10, 15)
self.dense = tf.keras.layers.Dense(10)

@tf.function
def inference(self, inp):
print('call function: inference')
inp_em = self.embedding(inp)
inp_enc = self.dense(inp_em)

return inp_enc

@tf.function
def fun(self, inp):
print('call function: fun')
return self.inference(inp)

model = Model()

当我第一次运行以下代码时:

a = model.fun(np.array([1, 2, 3]))
print('=' * 20)
a = model.inference(np.array([1, 2, 3]))

输出是

call function: fun
call function: inference
call function: inference
====================
call function: inference

好像tensorflow为推理函数建立了三个图,我怎么能只为推理函数建立一个图呢?而且我还想知道当两个函数相互调用时 tf.function 是如何工作的。这是构建我的模型的正确方法吗?

最佳答案

有时 tf.function 的执行方式会给我们带来一些困惑 - 特别是当我们混合使用 print() 等 vanilla python 操作时。

我们应该记住,当我们用 tf.function 装饰一个函数时,它不再只是一个 python 函数。它的行为略有不同,以便在 TF 中快速高效地使用。绝大多数时候,行为的变化几乎不明显(除了速度提高!)但偶尔我们会遇到像这样的细微差别。

首先要注意的是,如果我们使用 tf.print() 代替 print(),那么我们会得到预期的输出:

class Model:
def __init__(self):
self.embedding = tf.keras.layers.Embedding(10, 15)
self.dense = tf.keras.layers.Dense(10)

@tf.function
def inference(self, inp):
tf.print('call function: inference')
inp_em = self.embedding(inp)
inp_enc = self.dense(inp_em)

return inp_enc

@tf.function
def fun(self, inp):
tf.print('call function: fun')
return self.inference(inp)

model = Model()

a = model.fun(np.array([1, 2, 3]))
print('=' * 20)
a = model.inference(np.array([1, 2, 3]))

输出:

call function: fun
call function: inference
====================
call function: inference

如果您的问题是现实世界问题的征兆,这可能就是解决方法!

这是怎么回事?

我们第一次调用用 tf.function 修饰的函数时,tensorflow 将构建一个执行图。为了做到这一点,它“跟踪”了 python 函数执行的 tensorflow 操作。

为了执行此跟踪,tensorflow 可能会多次调用装饰函数!

这意味着只有 python 操作(比如 print() 可以执行多次)但是 tf 操作比如 tf.print() 将表现为你通常会期望。

这种细微差别的副作用是我们应该知道 tf.function 装饰函数如何处理状态,但这不在您的问题范围内。查看original RFCthis github issue了解更多信息。

And I also want to know how tf.function woks when two functions call each other. Is this the right way to build my model?

一般来说,我们需要只用tf.function装饰“外部”函数(在你的例子中是.fun())但是如果您也可以直接调用内部函数,那么您也可以自由装饰它。

关于python - 当两个函数相互调用时,tf.function 是如何工作的,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57675242/

30 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com