gpt4 book ai didi

tensorflow - 创建 keras 回调以在训练期间保存每批的模型预测和目标

转载 作者:行者123 更新时间:2023-12-02 23:37:57 28 4
gpt4 key购买 nike

我正在 Keras( tensorflow 后端)构建一个简单的顺序模型。在训练期间,我想检查各个训练批处理和模型预测。因此,我尝试创建一个自定义回调来保存每个训练批处理的模型预测和目标。但是,该模型并不是使用当前批处理进行预测,而是使用整个训练数据。

如何仅将当前训练批处理移交给回调

如何访问 Callback 保存在 self.predhis 和 self.targets 中的批处理和目标?

我当前的版本如下:

callback_list = [prediction_history((self.x_train, self.y_train))]

self.model.fit(self.x_train, self.y_train, batch_size=self.batch_size, epochs=self.n_epochs, validation_data=(self.x_val, self.y_val), callbacks=callback_list)

class prediction_history(keras.callbacks.Callback):
def __init__(self, train_data):
self.train_data = train_data
self.predhis = []
self.targets = []

def on_batch_end(self, epoch, logs={}):
x_train, y_train = self.train_data
self.targets.append(y_train)
prediction = self.model.predict(x_train)
self.predhis.append(prediction)
tf.logging.info("Prediction shape: {}".format(prediction.shape))
tf.logging.info("Targets shape: {}".format(y_train.shape))

最佳答案

注意:此答案已过时,仅适用于 TF1。检查@bers的answer在 TF2 上测试的解决方案。

<小时/>

模型编译后,占位符张量为 y_true位于model.targetsy_pred位于model.outputs .

要在每个批处理中保存这些占位符的值,您可以:

  1. 首先将这些张量的值复制到变量中。
  2. 评估 on_batch_end 中的这些变量,并存储结果数组。

现在第 1 步有点复杂,因为您必须添加 tf.assign op 到训练函数model.train_function 。使用当前的 Keras API,这可以通过提供 fetches 来完成K.function() 的参数当构建训练函数时。

model._make_train_function() ,有一行:

self.train_function = K.function(inputs,
[self.total_loss] + self.metrics_tensors,
updates=updates,
name='train_function',
**self._function_kwargs)

fetches包含 tf.assign 的参数可以通过 model._function_kwargs 提供操作(仅适用于Keras 2.1.0之后)。

举个例子:

from keras.layers import Dense
from keras.models import Sequential
from keras.callbacks import Callback
from keras import backend as K
import tensorflow as tf
import numpy as np

class CollectOutputAndTarget(Callback):
def __init__(self):
super(CollectOutputAndTarget, self).__init__()
self.targets = [] # collect y_true batches
self.outputs = [] # collect y_pred batches

# the shape of these 2 variables will change according to batch shape
# to handle the "last batch", specify `validate_shape=False`
self.var_y_true = tf.Variable(0., validate_shape=False)
self.var_y_pred = tf.Variable(0., validate_shape=False)

def on_batch_end(self, batch, logs=None):
# evaluate the variables and save them into lists
self.targets.append(K.eval(self.var_y_true))
self.outputs.append(K.eval(self.var_y_pred))

# build a simple model
# have to compile first for model.targets and model.outputs to be prepared
model = Sequential([Dense(5, input_shape=(10,))])
model.compile(loss='mse', optimizer='adam')

# initialize the variables and the `tf.assign` ops
cbk = CollectOutputAndTarget()
fetches = [tf.assign(cbk.var_y_true, model.targets[0], validate_shape=False),
tf.assign(cbk.var_y_pred, model.outputs[0], validate_shape=False)]
model._function_kwargs = {'fetches': fetches} # use `model._function_kwargs` if using `Model` instead of `Sequential`

# fit the model and check results
X = np.random.rand(10, 10)
Y = np.random.rand(10, 5)
model.fit(X, Y, batch_size=8, callbacks=[cbk])

除非样本数量可以除以批处理大小,否则最终批处理的大小将与其他批处理的大小不同。所以K.variable()K.update()在这种情况下不能使用。您必须使用 tf.Variable(..., validate_shape=False)tf.assign(..., validate_shape=False)相反。

<小时/>

要验证保存的数组的正确性,可以在 training.py 中添加一行打印出打乱后的索引数组:

if shuffle == 'batch':
index_array = _batch_shuffle(index_array, batch_size)
elif shuffle:
np.random.shuffle(index_array)

print('Index array:', repr(index_array)) # Add this line

batches = _make_batches(num_train_samples, batch_size)

在拟合过程中应打印出打乱后的索引数组:

Epoch 1/1Index array: array([8, 9, 3, 5, 4, 7, 1, 0, 6, 2])10/10 [==============================] - 0s 23ms/step - loss: 0.5670

And you can check if cbk.targets is the same as Y[index_array]:

index_array = np.array([8, 9, 3, 5, 4, 7, 1, 0, 6, 2])
print(Y[index_array])
[[ 0.75325592 0.64857277 0.1926653 0.7642865 0.38901153]
[ 0.77567689 0.13573623 0.4902501 0.42897559 0.55825652]
[ 0.33760938 0.68195038 0.12303088 0.83509441 0.20991668]
[ 0.98367778 0.61325065 0.28973401 0.28734073 0.93399794]
[ 0.26097574 0.88219054 0.87951941 0.64887846 0.41996446]
[ 0.97794604 0.91307569 0.93816428 0.2125808 0.94381495]
[ 0.74813435 0.08036688 0.38094272 0.83178364 0.16713736]
[ 0.52609421 0.39218962 0.21022047 0.58569125 0.08012982]
[ 0.61276627 0.20679494 0.24124858 0.01262245 0.0994412 ]
[ 0.6026137 0.25620512 0.7398164 0.52558182 0.09955769]]

print(cbk.targets)
[array([[ 0.7532559 , 0.64857274, 0.19266529, 0.76428652, 0.38901153],
[ 0.77567691, 0.13573623, 0.49025011, 0.42897558, 0.55825651],
[ 0.33760938, 0.68195039, 0.12303089, 0.83509439, 0.20991668],
[ 0.9836778 , 0.61325067, 0.28973401, 0.28734073, 0.93399793],
[ 0.26097575, 0.88219053, 0.8795194 , 0.64887846, 0.41996446],
[ 0.97794604, 0.91307569, 0.93816429, 0.2125808 , 0.94381493],
[ 0.74813437, 0.08036689, 0.38094273, 0.83178365, 0.16713737],
[ 0.5260942 , 0.39218962, 0.21022047, 0.58569127, 0.08012982]], dtype=float32),
array([[ 0.61276627, 0.20679495, 0.24124858, 0.01262245, 0.0994412 ],
[ 0.60261369, 0.25620511, 0.73981643, 0.52558184, 0.09955769]], dtype=float32)]

可以看到cbk.targets中有两批(一个“完整批处理”的大小为 8,最后一批的大小为 2),行顺序与 Y[index_array] 相同。 .

关于tensorflow - 创建 keras 回调以在训练期间保存每批的模型预测和目标,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47079111/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com