
python - Correctly setting the .trainable attribute in a GAN implementation in tf.keras


I am confused about the tf.keras Model.trainable attribute as it is used in GAN implementations.

Given the following code snippet (taken from this repo):

class GAN():

    def __init__(self):

        ...

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        validity = self.discriminator(img)

        # The combined model (stacked generator and discriminator)
        # Trains the generator to fool the discriminator
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    def build_generator(self):

        ...

        return Model(noise, img)

    def build_discriminator(self):

        ...

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):

        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data()

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Train the generator (to have the discriminator label samples as valid)
            g_loss = self.combined.train_on_batch(noise, valid)

During the definition of self.combined, the discriminator's weights are frozen with self.discriminator.trainable = False, but this flag is never switched back on.

Nevertheless, during the training loop the discriminator's weights do change:

# Train the discriminator
d_loss_real = self.discriminator.train_on_batch(imgs, valid)
d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

and remain unchanged during:

# Train the generator (to have the discriminator label samples as valid)
g_loss = self.combined.train_on_batch(noise, valid)

This is not what I expected.
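One way to confirm which weights actually moved is to snapshot them with get_weights() around each training call. A small comparison helper for that purpose (plain NumPy, independent of any particular Keras version; the usage lines are a hypothetical sketch, not code from the repo) might look like:

```python
import numpy as np

def weights_changed(before, after, tol=1e-7):
    """Return True if any corresponding weight array differs by more than tol.

    `before` and `after` are lists of arrays, as returned by a
    tf.keras Model's get_weights() before and after a training call.
    """
    return any(np.max(np.abs(b - a)) > tol for b, a in zip(before, after))

# Hypothetical usage around the training calls above:
#   before = [w.copy() for w in gan.discriminator.get_weights()]
#   gan.combined.train_on_batch(noise, valid)
#   weights_changed(before, gan.discriminator.get_weights())
```
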

Of course, this alternating scheme is the correct way to train a GAN, but I don't understand why we don't have to set self.discriminator.trainable = True before training the discriminator.

If someone has an explanation for this, that would be great; I suppose it is a key point to understand.

Best Answer

When you have questions about the code in a GitHub repository, it is usually a good idea to check the issues (both open and closed). This issue explains why the flag is set to False. It says,

Since self.discriminator.trainable = False is set after the discriminator is compiled, it will not affect the training of the discriminator. However since it is set before the combined model is compiled the discriminator layers will be frozen when the combined model is trained.

The discussion also covers freezing keras layers.
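The mechanism the issue describes — compile() taking a snapshot of which weights are trainable, so that flipping .trainable afterwards has no effect until the next compile() — can be illustrated with a deliberately simplified toy. This is not the real Keras implementation, just the idea:

```python
class ToyLayer:
    """Stand-in for a Keras layer holding one scalar weight."""
    def __init__(self, name, weight=0.0):
        self.name = name
        self.weight = weight
        self.trainable = True

class ToyModel:
    """Loosely mimics classic Keras: compile() snapshots the trainable
    weights, and training only ever updates that snapshot."""
    def __init__(self, layers):
        self.layers = layers
        self._collected_trainable = None

    def compile(self):
        # Only layers that are trainable *right now* will ever be updated.
        self._collected_trainable = [l for l in self.layers if l.trainable]

    def train_on_batch(self, lr=0.1):
        # Pretend an update just decrements each snapshotted weight by lr.
        for layer in self._collected_trainable:
            layer.weight -= lr

d = ToyLayer("discriminator")
g = ToyLayer("generator")

discriminator = ToyModel([d])
discriminator.compile()          # snapshot taken while d.trainable is True

d.trainable = False              # flipped AFTER the discriminator's compile
combined = ToyModel([g, d])
combined.compile()               # snapshot here contains only g

discriminator.train_on_batch()   # still updates d: its snapshot predates the flip
combined.train_on_batch()        # updates g only; d is frozen inside combined
# d.weight and g.weight have each moved exactly once
```

This is exactly the asymmetry the asker observed: the same discriminator weights are mutable through one compiled model and frozen through the other, with no toggling needed in the training loop.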

Regarding "python - Correctly setting the .trainable attribute in a GAN implementation in tf.keras", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/58803868/
