gpt4 book ai didi

tensorflow - 将 Keras ModelCheckpoints 保存在 Google Cloud Bucket 中

转载 作者:行者123 更新时间:2023-12-04 01:31:42 26 4
gpt4 key购买 nike

我正在使用带有 TensorFlow 后端的 Keras 在 Google Cloud Machine Learning Engine 上训练 LSTM 网络。在对 gcloud 和我的 python 脚本进行一些调整后,我管理它来部署我的模型并执行成功的训练任务。

然后我尝试使用 Keras modelCheckpoint callback 让我的模型在每个 epoch 后保存检查点.使用 Google Cloud 运行本地培训作业可以按预期完美运行。在每个 epoch 之后,权重都存储在指定的路径中。但是当我尝试在 Google Cloud Machine Learning Engine 上在线运行相同的作业时,weights.hdf5没有写入我的谷歌云存储桶。相反,我收到以下错误:

...
File "h5f.pyx", line 71, in h5py.h5f.open (h5py/h5f.c:1797)
IOError: Unable to open file (Unable to open file: name =
'gs://.../weights.hdf5', errno = 2, error message = 'no such file or
directory', flags = 0, o_flags = 0)

我调查了这个问题,结果发现 Bucket 本身没有问题,如 Keras Tensorboard callback确实可以正常工作并将预期的输出写入同一个存储桶。我还确定了 h5py通过在 setup.py 中提供它来包含位于:
├── setup.py
└── trainer
├── __init__.py
├── ...

实际包含在 setup.py如下图所示:
# setup.py
from setuptools import setup, find_packages

setup(name='kerasLSTM',
version='0.1',
packages=find_packages(),
author='Kevin Katzke',
install_requires=['keras','h5py','simplejson'],
zip_safe=False)

我想问题归结为 Python 无法访问 GCS open用于 I/O,因为它提供了一个自定义实现:
import tensorflow as tf
from tensorflow.python.lib.io import file_io

with file_io.FileIO("gs://...", 'r') as f:
f.write("Hi!")

在检查了 Keras modelCheckpoint 回调如何实现实际的文件写入后,结果证明它正在使用 h5py.File()对于 I/O:
 with h5py.File(filepath, mode='w') as f:
f.attrs['keras_version'] = str(keras_version).encode('utf8')
f.attrs['backend'] = K.backend().encode('utf8')
f.attrs['model_config'] = json.dumps({
'class_name': model.__class__.__name__,
'config': model.get_config()
}, default=get_json_type).encode('utf8')

而作为 h5py packageHDF5 binary data format 的 Pythonic 接口(interface) h5py.File()似乎调用了底层 HDF5据我所知,用 Fortran 编写的功能: source , documentation .

如何解决此问题并使 modelCheckpoint 回调写入我的 GCS 存储桶?有没有办法让“猴子补丁”以某种方式覆盖 hdf5 文件的打开方式以使其使用 GCS 的 file_io.FileIO() ?

最佳答案

对我来说,最简单的方法是使用 gsutil。

model.save('model.h5')
!gsutil -m cp model.h5 gs://name-of-cloud-storage/model.h5

关于tensorflow - 将 Keras ModelCheckpoints 保存在 Google Cloud Bucket 中,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45585104/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com