
python - How to disable the warning "tensorflow:Method (on_train_batch_end) is slow compared to the batch update (). Check your callbacks"

Reposted. Author: 太空宇宙. Updated: 2023-11-03 19:52:25

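I am trying to implement Stochastic Weight Averaging (SWA) with Keras-style TensorFlow 2.0, so I need to update the SWA model weights at every step. I wrote a custom callback to do this, but I get a warning at every step. Here are the details:

My custom callback: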


import gc

import tensorflow as tf


class CustomCallback(tf.keras.callbacks.Callback):
    def __init__(self, valid_data, output_path, swa_alpha=0.99, eval_every=500, eval_batch=16, fold=None):
        super().__init__()
        self.valid_inputs = valid_data[0]
        self.valid_outputs = valid_data[1]
        self.eval_batch = eval_batch
        self.swa_alpha = swa_alpha
        self.fold = fold
        self.output_path = output_path
        self.rho_value = -1  # record the best rho for report
        self.eval_every = eval_every

    def on_train_begin(self, logs=None):
        self.swa_weights = self.model.get_weights()

    def on_batch_end(self, batch, logs=None):
        # update swa parameters
        alpha = min(1 - 1 / (batch + 1), self.swa_alpha)
        current_weights = self.model.get_weights()
        # get_weights() returns one array per weight tensor, not one per layer,
        # so iterate over the weight list itself rather than over model.layers
        for i, weights in enumerate(current_weights):
            self.swa_weights[i] = alpha * self.swa_weights[i] + (1 - alpha) * weights

        # validation
        if batch > 0 and batch % self.eval_every == 0:
            # do validation
            val_pred = self.model.predict(self.valid_inputs, batch_size=self.eval_batch)
            # compute_spearmanr is the user-defined metric helper (not shown here)
            rho_val = compute_spearmanr(self.valid_outputs, val_pred)

            # set the swa parameters and do validation
            self.model.set_weights(self.swa_weights)
            swa_val_pred = self.model.predict(self.valid_inputs, batch_size=self.eval_batch)
            swa_rho_val = compute_spearmanr(self.valid_outputs, swa_val_pred)

            # reset the original parameters
            self.model.set_weights(current_weights)

            # check whether to save model and update best rho value
            if rho_val > self.rho_value:
                self.rho_value = rho_val
                # use self.fold: a bare `fold` is not defined inside this method
                self.model.save_weights(f'{self.output_path}/fold-{self.fold}-best.h5')

        del current_weights
        gc.collect()

The output looks like this:

WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (11.428264). Check your callbacks.
WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (11.464315). Check your callbacks.
WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (11.502968). Check your callbacks.
WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (11.518413). Check your callbacks.

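I get the warning at every step, which means that even when the validation code does not run, the code that updates the SWA parameters (self.model.get_weights() and the for loop below it) is slow enough to trigger it.

I know that updating the parameters is very slow because model.get_weights() and model.set_weights() both deep-copy the parameters (in my experiments, each call produces a new list of new numpy ndarrays).

I don't think there is anything wrong with my SWA implementation (please tell me if there is a mistake), so I just want to disable the warning.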
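For reference, the per-step update used in the callback can be sketched in plain Python, with floats standing in for weight arrays. The helper names `swa_alpha_for_step` and `swa_update` are hypothetical, and the sketch assumes a global step counter that keeps increasing across epochs (Keras restarts the `batch` index at 0 every epoch, which would reset alpha to 0). While the step count is small, the rule reproduces the arithmetic mean of the weights seen so far. Once `1 - 1 / (step + 1)` exceeds `swa_alpha`, it becomes an exponential moving average.

```python
def swa_alpha_for_step(step, swa_alpha=0.99):
    """Averaging coefficient for a 0-based step (same formula as in the callback)."""
    return min(1 - 1 / (step + 1), swa_alpha)


def swa_update(swa_weights, current_weights, step, swa_alpha=0.99):
    """Blend the current weights into the running average, in place."""
    alpha = swa_alpha_for_step(step, swa_alpha)
    for i, w in enumerate(current_weights):
        swa_weights[i] = alpha * swa_weights[i] + (1 - alpha) * w
    return swa_weights


# Hypothetical one-parameter "model" whose single weight takes these values
# over four consecutive training steps.
history = [[1.0], [3.0], [5.0], [7.0]]
swa = list(history[0])
for step, weights in enumerate(history):
    swa_update(swa, weights, step)

# Up to floating-point rounding, swa[0] equals the mean of 1, 3, 5, 7.
print(swa)
```

With real models the same loop would run over the list returned by get_weights(), one array per weight tensor.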

What I have tried:

  1. Adding os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" or os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" to disable the warning.
  2. Setting verbose to 2 or 0 in model.fit(), i.e. model.fit(..., verbose=2, ...) or model.fit(..., verbose=0, ...).

Neither works.

Any ideas? Thanks in advance for any help!

Best answer

This is not a very satisfying answer, but TF_CPP_MIN_LOG_LEVEL not working is a known issue: "TF_CPP_MIN_LOG_LEVEL does not work with TF2.0 dev20190820".

I was able to reproduce your problem on tensorflow==2.1.0-rc1 with this toy example:

import os
import time

os.environ['TF_CPP_MIN_LOG_LEVEL'] = "2"
import tensorflow as tf

tf.get_logger().setLevel("WARNING")
tf.autograph.set_verbosity(2)

print(tf.__version__)

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])


class CustomCallback(tf.keras.callbacks.Callback):

    def on_train_batch_end(self, batch, logs=None):
        # simulate a slow callback
        time.sleep(3)


model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=1, callbacks=[CustomCallback()])

Output:

2.1.0
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 31s 3us/step
Train on 60000 samples
WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (3.002797). Check your callbacks.
32/60000 [..............................] - ETA: 1:57:38 - loss: 2.4674 - accuracy: 0.0938WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (3.002938). Check your callbacks.
...

None of the standard suggestions (os.environ['TF_CPP_MIN_LOG_LEVEL'], tf.get_logger().setLevel("WARNING"), tf.autograph.set_verbosity(2)) works, and I suspect you will have to wait until the issue above is resolved.
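One avenue that may be worth trying in the meantime: the "WARNING:tensorflow:" prefix looks like a record from Python's standard logging module, whereas TF_CPP_MIN_LOG_LEVEL targets the C++ side. The sketch below attaches a logging filter to the standard-library logger named "tensorflow" (the logger that tf.get_logger() returns) and drops only this one message. It assumes the warning is emitted directly on that logger and has not been verified against 2.1.0-rc1. The class name `DropSlowCallbackWarning` is hypothetical.

```python
import logging


class DropSlowCallbackWarning(logging.Filter):
    """Drop only the 'slow callback' warning and let every other record through."""

    def filter(self, record):
        # getMessage() applies the %-style arguments, so the check runs
        # against the final text of the log line.
        return "is slow compared to the batch update" not in record.getMessage()


# Looked up by name so that this sketch runs without importing TensorFlow.
# Assumption: the warning is logged on this logger, not on a child logger.
tf_logger = logging.getLogger("tensorflow")
tf_logger.addFilter(DropSlowCallbackWarning())
```

The filter would need to be installed before model.fit() is called. A blunter alternative is to raise the logger level to ERROR, which hides every TensorFlow warning and not just this one.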

Regarding python - How to disable the warning "tensorflow:Method (on_train_batch_end) is slow compared to the batch update (). Check your callbacks", a similar question was found on Stack Overflow: https://stackoverflow.com/questions/59749499/
