
python-3.x - Is eager TensorFlow loading now?

Reposted. Author: 行者123. Updated: 2023-12-04 12:43:00

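The weights in a class that inherits from tf.keras.Model don't seem to load at the moment. I couldn't load the weights of Example() from outside the class with a checkpoint, so I tried doing it inside the class, which by all accounts should work. It can save the weights, just as when saving Example() from outside, but it still can't load them. Here is my model code: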

class Example(tf.keras.Model):
    def __init__(self, cfg):
        super(Example, self).__init__()

        self.model = tf.keras.Sequential([
            ........layers.......
        ])

        # Create saver
        self.save_path = cfg.save_dir + cfg.extension
        self.ckpt_prefix = self.save_path + '/ckpt'
        self.saver = tf.train.Checkpoint(model=self.model)

    def call(self, x_in):
        x_out = self.model(x_in)
        return x_out

    def save(self):
        self.saver.save(file_prefix=self.ckpt_prefix)

    def load(self):
        self.saver.restore(tf.train.latest_checkpoint(self.save_path))

Here is what I use to check whether it loads:

from pathlib import Path

example = Example(cfg)  # cfg: the config object from my code (not shown)
if Path(example.save_path).is_dir():
    print(example.weights)
    print(example.model.weights)
    example.load()
    print(example.weights)
    print(example.model.weights)

Output:

[]
[]
[]
[]

I tested this on both TensorFlow 1.3 and 2.0. I can confirm that the weights are non-empty after the first batch and that the model is being checkpointed/saved.

Best answer

It turns out that TensorFlow restores a checkpoint in one of three ways, depending on what was checkpointed.

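  1. The checkpointed object is just a variable. It is restored immediately when checkpoint.restore(tf.train.latest_checkpoint(checkpoint_path)) is called.

  2. The checkpointed object is a model with a defined input shape. It is also restored immediately.

  3. The checkpointed object is a model without a defined input shape. Here the behavior changes: TensorFlow performs a "delayed" restoration and doesn't restore the model's weights until input has been passed to the model.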
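The toy sketch below is free of TensorFlow and makes the deferred case concrete. `LazyModel` and its `restore` method are hypothetical names invented for illustration. They are not TensorFlow APIs, and this is not how TensorFlow is implemented internally. Like a Keras model with no input shape, the toy creates its variables only on its first call. A restore that happens before that call can therefore only queue the checkpoint values, and they are applied once the variables exist.

```python
class LazyModel:
    """Toy stand-in for a Keras model (hypothetical, not a TensorFlow class).

    Variables are created on the first call unless `input_dim` is given
    up front, mirroring a model with or without a defined input shape.
    """

    def __init__(self, input_dim=None):
        self.weights = {}    # variable name -> value; empty until built
        self._deferred = {}  # checkpoint values waiting for their variables
        if input_dim is not None:
            self._build(input_dim)

    def _build(self, input_dim):
        # Create the variables, then apply any values queued by restore().
        self.weights = {"kernel": [0.0] * input_dim, "bias": 0.0}
        self.weights.update(self._deferred)
        self._deferred = {}

    def restore(self, saved):
        # Assign values to variables that already exist; queue the rest.
        for name, value in saved.items():
            if name in self.weights:
                self.weights[name] = value
            else:
                self._deferred[name] = value

    def __call__(self, x):
        if not self.weights:  # first call: infer the input size and build
            self._build(len(x))
        kernel, bias = self.weights["kernel"], self.weights["bias"]
        return sum(k * xi for k, xi in zip(kernel, x)) + bias


saved = {"kernel": [1.0, 2.0], "bias": 0.5}

# Case 2: input size known, variables exist, restore applies immediately.
shaped = LazyModel(input_dim=2)
shaped.restore(saved)
print(shaped.weights)  # {'kernel': [1.0, 2.0], 'bias': 0.5}

# Case 3: no input size, no variables yet, values are only queued.
shapeless = LazyModel()
shapeless.restore(saved)
weights_before_call = dict(shapeless.weights)
print(weights_before_call)  # {}
output = shapeless([1.0, 1.0])  # first call builds and applies queued values
print(output, shapeless.weights)  # 3.5 {'kernel': [1.0, 2.0], 'bias': 0.5}
```

A value still waiting in `_deferred` roughly corresponds to the "unresolved object" that `status.assert_consumed()` reports in the answer's traceback.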

Here is an example:

import os
import tensorflow as tf
import numpy as np

# Disable logging
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
tf.logging.set_verbosity(tf.logging.ERROR)
tf.enable_eager_execution()

# Create model
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(256, 3, padding="same"),
    tf.keras.layers.Conv2D(3, 3, padding="same")
])
print("Are weights empty before training?", model.weights == [])

# Create optim, checkpoint
optimizer = tf.train.AdamOptimizer(0.001)
checkpoint = tf.train.Checkpoint(model=model)

# Make fake data
img = np.random.uniform(0, 255, (1, 32, 32, 3)).astype(np.float32)
truth = np.random.uniform(0, 255, (1, 32, 32, 3)).astype(np.float32)
# Train
with tf.GradientTape() as tape:
    logits = model(img)
    loss = tf.losses.mean_squared_error(truth, logits)

# Compute/apply gradients
grads = tape.gradient(loss, model.trainable_weights)
grads_and_vars = zip(grads, model.trainable_weights)
optimizer.apply_gradients(grads_and_vars)

# Save model
checkpoint_path = './ckpt/'
checkpoint.save(checkpoint_path)

# Check if weights update
print("Are weights empty after training?", model.weights == [])

# Reset model
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(256, 3, padding="same"),
    tf.keras.layers.Conv2D(3, 3, padding="same")
])
print("Are weights empty when resetting model?", model.weights == [])

# Update checkpoint pointer
checkpoint = tf.train.Checkpoint(model=model)
# Restore values from the checkpoint
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_path))

# This next line is REQUIRED to restore
#model(img)

print("Are weights empty after restoring from checkpoint?", model.weights == [])
print(status)
status.assert_existing_objects_matched()
status.assert_consumed()

Output:

Are weights empty before training? True
Are weights empty after training? False
Are weights empty when resetting model? True
Are weights empty after restoring from checkpoint? True
<tensorflow.python.training.checkpointable.util.CheckpointLoadStatus object at 0x7f6256b4ddd8>
Traceback (most recent call last):
  File "test.py", line 58, in <module>
    status.assert_consumed()
  File "/home/jpatts/.local/lib/python3.6/site-packages/tensorflow/python/training/checkpointable/util.py", line 1013, in assert_consumed
    raise AssertionError("Unresolved object in checkpoint: %s" % (node,))
AssertionError: Unresolved object in checkpoint: attributes {
  name: "VARIABLE_VALUE"
  full_name: "sequential/conv2d/kernel"
  checkpoint_key: "model/layer-0/kernel/.ATTRIBUTES/VARIABLE_VALUE"
}

However, uncommenting the line model(img) produces the following output:

Are weights empty before training? True
Are weights empty after training? False
Are weights empty when resetting model? True
Are weights empty after restoring from checkpoint? False
<tensorflow.python.training.checkpointable.util.CheckpointLoadStatus object at 0x7ff62320fe48>

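Therefore, for a model without a defined input shape, input data has to be passed through the model before its weights are actually restored.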
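Another way to avoid relying on the first call is to give the model its input shape before restoring, which puts it in case 2 of the list above. You can do this with `model.build(...)` or by calling the model once on a dummy batch. The sketch below assumes TensorFlow 2.x. The two `Dense` layers, the `(None, 4)` input shape and the temporary directory are placeholders, not the asker's real architecture or paths. In the question's `Example` class, the analogous change would be to build `self.model` with the real input shape, either in `__init__` or at the start of `load()`.

```python
import tempfile

import numpy as np
import tensorflow as tf

INPUT_SHAPE = (None, 4)  # placeholder: batch dimension plus 4 features


def make_model():
    # Placeholder architecture standing in for the real layers.
    return tf.keras.Sequential([
        tf.keras.layers.Dense(8, activation="relu"),
        tf.keras.layers.Dense(2),
    ])


ckpt_dir = tempfile.mkdtemp()

# Saving side: build so the variables exist, then write a checkpoint.
model = make_model()
model.build(INPUT_SHAPE)
tf.train.Checkpoint(model=model).save(ckpt_dir + "/ckpt")

# Restoring side: build a fresh model *before* restoring, so the restore
# assigns values to existing variables instead of deferring them.
restored = make_model()
restored.build(INPUT_SHAPE)
status = tf.train.Checkpoint(model=restored).restore(
    tf.train.latest_checkpoint(ckpt_dir))
status.assert_existing_objects_matched()

# The restored weights should reproduce the original model's outputs.
x = np.ones((1, 4), dtype=np.float32)
outputs_match = np.allclose(np.asarray(model(x)), np.asarray(restored(x)))
print("Weights restored:", len(restored.weights) > 0 and outputs_match)
```

Building up front trades a little flexibility (the input shape must be known when the model is created) for a restore that takes effect immediately. It also makes `print(model.weights)` meaningful right after loading.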

References:

  1. https://www.tensorflow.org/alpha/guide/checkpoints#delayed_restorations
  2. https://github.com/tensorflow/tensorflow/issues/27937

For more on "python-3.x - Is eager TensorFlow loading now?", see the similar question on Stack Overflow: https://stackoverflow.com/questions/55719047/
