gpt4 book ai didi

python - Pytorch model.train() 和教程中编写的单独的 train() 函数

转载 作者:行者123 更新时间:2023-11-30 09:42:42 26 4
gpt4 key购买 nike

我是 PyTorch 的新手,我想知道您是否可以向我解释 PyTorch 中的默认 model.train() 与此处的 train() 函数之间的一些关键区别。

另一个 train() 函数位于关于文本分类的官方 PyTorch 教程中,对于是否在训练结束时存储模型权重感到困惑。

https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html

learning_rate = 0.005

criterion = nn.NLLLoss()

def train(category_tensor, line_tensor):
hidden = rnn.initHidden()
rnn.zero_grad()
for i in range(line_tensor.size()[0]):
output, hidden = rnn(line_tensor[i], hidden)
loss = criterion(output, category_tensor)
loss.backward()
# Add parameters' gradients to their values, multiplied by learning rate
for p in rnn.parameters():
p.data.add_(-learning_rate, p.grad.data)
return output, loss.item()

这就是函数。然后以这种形式多次调用该函数:

n_iters = 100000
print_every = 5000
plot_every = 1000
record_every = 500

# Keep track of losses for plotting
current_loss = 0
all_losses = []
predictions = []
true_vals = []

def timeSince(since):
now = time.time()
s = now - since
m = math.floor(s / 60)
s -= m * 60
return '%dm %ds' % (m, s)

start = time.time()

for iter in range(1, n_iters + 1):
category, line, category_tensor, line_tensor = randomTrainingExample()
output, loss = train(category_tensor, line_tensor)
current_loss += loss

if iter % print_every == 0:
guess, guess_i = categoryFromOutput(output)
correct = 'O' if guess == category else 'X (%s)' % category
print('%d %d%% (%s) %.4f %s / %s %s' % (iter, iter / n_iters * 100, timeSince(start), loss, line, guess, correct))

if iter % plot_every == 0:
all_losses.append(current_loss / plot_every)
current_loss = 0

if iter % record_every == 0:
guess, guess_i = categoryFromOutput(output)
predictions.append(guess)
true_vals.append(category)

对我来说,模型权重似乎没有被保存或更新,而是在每次迭代时被重写,这样编写。它是否正确?或者模型训练是否正确?

此外,如果我使用默认函数 model.train(),主要优点是什么? model.train() 是否执行与上面的 train() 函数或多或少相同的功能?

最佳答案

根据源代码here , model.train() 将模块设置为训练模式。因此,它基本上告诉您的模型您正在训练该模型。这仅对某些模块有影响,例如 dropout、batchnorm 等,这些模块在训练/评估模式下的行为有所不同。在 model.train() 的情况下,模型知道它必须学习各层。

您可以调用 model.eval()model.train(mode=False) 来告诉模型它没有新的东西需要学习,并且模型正在运行用于测试目的。

model.train() 只是设置模式。它实际上并不训练模型。

您上面使用的

train() 实际上是训练模型,即计算梯度并进行反向传播来学习权重。

从官方 pytorch 讨论论坛 here 了解有关 model.train() 的更多信息.

关于python - Pytorch model.train() 和教程中编写的单独的 train() 函数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56758445/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com