gpt4 book ai didi

python - 训练前运行tensorflow模型,否则它不会训练?

转载 作者:行者123 更新时间:2023-12-05 08:03:58 24 4
gpt4 key购买 nike

今天在运行tensorflow代码的时候发现了一件很有意思的事情:

import matplotlib
from matplotlib import pyplot as plt
import tensorflow as tf
matplotlib.rcParams['figure.figsize'] = [9, 6]

x = tf.linspace(-2., 2., 201)
def f(x):
y = x**2 + 2*x - 5
return y
y = f(x) + tf.random.normal(shape=[201])

class Model(tf.keras.Model):
def __init__(self, units):
super().__init__()
self.dense1 = tf.keras.layers.Dense(units=units,
activation=tf.nn.relu,
kernel_initializer=tf.random.normal,
bias_initializer=tf.random.normal)
self.dense2 = tf.keras.layers.Dense(1)

def call(self, x, training=True):
# For Keras layers/models, implement `call` instead of `__call__`.
x = x[:, tf.newaxis]
x = self.dense1(x)
x = self.dense2(x)
return tf.squeeze(x, axis=1)

model = Model(64)

test = model(x) ################## model couldn't train without this line ####

variables = model.variables
optimizer = tf.optimizers.SGD(learning_rate=0.001)

for step in range(1000):
with tf.GradientTape() as tape:
prediction = model(x)
error = (y-prediction)**2
mean_error = tf.reduce_mean(error)
gradient = tape.gradient(mean_error, variables)
optimizer.apply_gradients(zip(gradient, variables))

if step % 100 == 0:
print(f'Mean squared error: {mean_error.numpy():0.3f}')

模型本身非常简单。有趣的是注释行。不调用模型一次,通过 test = model(x),模型根本不会训练!!!例如,如果我删除这一行。结果将是:

Mean squared error: 21.782
Mean squared error: 21.782
Mean squared error: 21.782
Mean squared error: 21.782
Mean squared error: 21.782
Mean squared error: 21.782
Mean squared error: 21.782
Mean squared error: 21.782
Mean squared error: 21.782
Mean squared error: 21.782

为什么需要这条线?

最佳答案

在训练模型之前,您必须构建并编译它。

构建模型会根据训练数据的 input_shape 创建模型的所有变量。

编译模型会设置您希望在训练期间使用的优化器和损失函数。

当您调用模型时,它会根据您插入的数据的形状自动构建。因此,您可以在调用模型后对其进行训练。

关于python - 训练前运行tensorflow模型,否则它不会训练?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/70701009/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com