gpt4 book ai didi

python - 为什么在 Keras 子类化 API 中,永远不会调用 call 方法,而是通过调用此类的对象来传递输入?

转载 作者:行者123 更新时间:2023-12-01 14:02:48 25 4
gpt4 key购买 nike

在使用 Keras 子类化 API 创建模型时,我们编写了一个自定义模型类并定义了一个名为 call(self, x) 的函数。 (主要是写前向传递)它需要一个输入。然而,这个方法永远不会被调用,而是将输入传递给 call ,它作为 model(images) 传递给此类的对象.

我们怎么能叫这个model当我们没有实现 Python 特殊方法时,对象和传递值,__call__在类里

class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = Conv2D(32, 3, activation='relu')
self.flatten = Flatten()
self.d1 = Dense(128, activation='relu')
self.d2 = Dense(10, activation='softmax')

def call(self, x):
x = self.conv1(x)
x = self.flatten(x)
x = self.d1(x)
return self.d2(x)

# Create an instance of the model
model = MyModel()

使用 tf.GradientTape 训练模型:
@tf.function
def train_step(images, labels):
with tf.GradientTape() as tape:
predictions = model(images)
loss = loss_object(labels, predictions)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))

train_loss(loss)
train_accuracy(labels, predictions)

输入不应该像下面这样传递:
model = MyModel()
model.call(images)

最佳答案

其实__call__方法在 Layer 中实现类,由 Network 继承类,由 Model 继承类(class):

class Layer(module.Module):
def __call__(self, inputs, *args, **kwargs):

class Network(base_layer.Layer):

class Model(network.Network):

所以 MyClass将继承这个 __call__方法。

附加信息:

所以实际上我们所做的是覆盖继承的 call方法,其中新增 call然后将从继承的 __call__ 调用方法方法。这就是为什么我们不需要做 model.call() .
所以当我们调用我们的模型实例时,它继承了 __call__方法会自动执行,调用我们自己的 call方法。

关于python - 为什么在 Keras 子类化 API 中,永远不会调用 call 方法,而是通过调用此类的对象来传递输入?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59473267/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com