
python - How to save/load models in AutoKeras 1.0

Reposted. Author: 行者123. Updated: 2023-12-05 07:15:41

I am using AutoKeras 1.0, and I can't work out how I am supposed to save and reload a trained model (together with its weights, etc.).

I can train a model easily with code similar to the following:

import autokeras as ak

# `common` (not shown in the question) is assumed to be a helper module, apparently
# AutoKeras's own test utilities, that generates synthetic structured data.
num_data = 500
train_x = common.generate_structured_data(num_data)
train_y = common.generate_one_hot_labels(num_instances=num_data, num_classes=3)
clf = ak.StructuredDataClassifier(
    column_names=common.COLUMN_NAMES_FROM_NUMPY,
    max_trials=1,
    seed=common.SEED)
clf.fit(train_x, train_y, epochs=4, validation_data=(train_x, train_y))
loss = clf.evaluate(train_x, train_y)
print(loss)

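However, I can't tell from the documentation how to save this model and reuse it later in another program. I tried getting the "best" model and saving it, like this: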

preprocess_graph, best_model = clf.tuner.get_best_model()
best_model.save("testmodel.h5")

However, when I try to load that model again, I get the following:

from keras.models import load_model  # standalone Keras, judging by the traceback paths
new_model = load_model("testmodel.h5")

---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-12-bd01053bfeda> in <module>
----> 1 new_model = load_model("testmodel.h5")

/opt/conda/lib/python3.7/site-packages/keras/engine/saving.py in load_wrapper(*args, **kwargs)
490 os.remove(tmp_filepath)
491 return res
--> 492 return load_function(*args, **kwargs)
493
494 return load_wrapper

/opt/conda/lib/python3.7/site-packages/keras/engine/saving.py in load_model(filepath, custom_objects, compile)
582 if H5Dict.is_supported_type(filepath):
583 with H5Dict(filepath, mode='r') as h5dict:
--> 584 model = _deserialize_model(h5dict, custom_objects, compile)
585 elif hasattr(filepath, 'write') and callable(filepath.write):
586 def load_function(h5file):

/opt/conda/lib/python3.7/site-packages/keras/engine/saving.py in _deserialize_model(h5dict, custom_objects, compile)
272 raise ValueError('No model found in config.')
273 model_config = json.loads(model_config.decode('utf-8'))
--> 274 model = model_from_config(model_config, custom_objects=custom_objects)
275 model_weights_group = h5dict['model_weights']
276

/opt/conda/lib/python3.7/site-packages/keras/engine/saving.py in model_from_config(config, custom_objects)
625 '`Sequential.from_config(config)`?')
626 from ..layers import deserialize
--> 627 return deserialize(config, custom_objects=custom_objects)
628
629

/opt/conda/lib/python3.7/site-packages/keras/layers/__init__.py in deserialize(config, custom_objects)
166 module_objects=globs,
167 custom_objects=custom_objects,
--> 168 printable_module_name='layer')

/opt/conda/lib/python3.7/site-packages/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
145 config['config'],
146 custom_objects=dict(list(_GLOBAL_CUSTOM_OBJECTS.items()) +
--> 147 list(custom_objects.items())))
148 with CustomObjectScope(custom_objects):
149 return cls.from_config(config['config'])

/opt/conda/lib/python3.7/site-packages/keras/engine/network.py in from_config(cls, config, custom_objects)
1054 # First, we create all layers and enqueue nodes to be processed
1055 for layer_data in config['layers']:
-> 1056 process_layer(layer_data)
1057
1058 # Then we process nodes in order of layer depth.

/opt/conda/lib/python3.7/site-packages/keras/engine/network.py in process_layer(layer_data)
1040
1041 layer = deserialize_layer(layer_data,
-> 1042 custom_objects=custom_objects)
1043 created_layers[layer_name] = layer
1044

/opt/conda/lib/python3.7/site-packages/keras/layers/__init__.py in deserialize(config, custom_objects)
166 module_objects=globs,
167 custom_objects=custom_objects,
--> 168 printable_module_name='layer')

/opt/conda/lib/python3.7/site-packages/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
147 list(custom_objects.items())))
148 with CustomObjectScope(custom_objects):
--> 149 return cls.from_config(config['config'])
150 else:
151 # Then `cls` may be a function returning a class.

/opt/conda/lib/python3.7/site-packages/keras/engine/base_layer.py in from_config(cls, config)
1177 A layer instance.
1178 """
-> 1179 return cls(**config)
1180
1181 def count_params(self):

/opt/conda/lib/python3.7/site-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
89 warnings.warn('Update your `' + object_name + '` call to the ' +
90 'Keras 2 API: ' + signature, stacklevel=2)
---> 91 return func(*args, **kwargs)
92 wrapper._original_function = func
93 return wrapper

TypeError: __init__() got an unexpected keyword argument 'ragged'

Am I doing something wrong, or is there a better way?
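The last frames of the traceback show the mechanism: `from_config` rebuilds each layer with `cls(**config)`, so any key in the saved config that the loading library's layer constructor does not recognise raises exactly this kind of `TypeError`. The input-layer config written by `tf.keras` in TensorFlow 2 includes a `ragged` entry, which the older standalone `keras` constructor most likely does not accept. A toy reproduction in plain Python, where `OldInputLayer` and `SAVED_CONFIG` are hypothetical stand-ins rather than real Keras objects:

```python
from typing import Any, Dict, Optional


class OldInputLayer:
    """Hypothetical stand-in for an input layer whose constructor predates `ragged`."""

    def __init__(self, batch_input_shape=None, dtype=None, sparse=False, name=None):
        self.batch_input_shape = batch_input_shape
        self.dtype = dtype
        self.sparse = sparse
        self.name = name

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "OldInputLayer":
        # Same pattern as the last frame of the traceback: every saved key
        # is passed straight through as a keyword argument.
        return cls(**config)


# Hypothetical config as a newer library might write it, with an extra `ragged` key.
SAVED_CONFIG = {
    "batch_input_shape": (None, 3),
    "dtype": "float32",
    "sparse": False,
    "ragged": False,
    "name": "input_1",
}


def try_load(config: Dict[str, Any]) -> Optional[str]:
    """Return None if the layer rebuilds, otherwise the TypeError message."""
    try:
        OldInputLayer.from_config(config)
        return None
    except TypeError as err:
        return str(err)


if __name__ == "__main__":
    # Fails with "... got an unexpected keyword argument 'ragged'".
    print(try_load(SAVED_CONFIG))
    # Without the unknown key, the old constructor accepts the config.
    print(try_load({k: v for k, v in SAVED_CONFIG.items() if k != "ragged"}))
```

Stripping keys by hand only illustrates the cause; it is not a real fix. The reliable option is to load the file with the same Keras implementation that saved it.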

Best answer

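The traceback paths (`site-packages/keras/engine/saving.py`) show the model being loaded with the standalone `keras` package, whereas AutoKeras 1.0 is built on TensorFlow 2 and saves the model through `tf.keras`. Try loading the saved model with `tf.keras` instead: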

import tensorflow as tf
# Load with tf.keras, the same Keras implementation that saved the model.
new_model = tf.keras.models.load_model('testmodel.h5')
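A self-contained round trip with `tf.keras` alone, using a tiny hypothetical model in place of the AutoKeras result, suggests why this works: saving and loading through the same library restores identical predictions. It is a sketch assuming TensorFlow 2.x with `h5py` installed (needed for the `.h5` format); `build_model` and `round_trip` are illustrative names, not part of any API.

```python
import numpy as np
import tensorflow as tf


def build_model() -> tf.keras.Model:
    # Hypothetical tiny architecture standing in for the model AutoKeras found.
    inputs = tf.keras.Input(shape=(4,))
    x = tf.keras.layers.Dense(8, activation="relu")(inputs)
    outputs = tf.keras.layers.Dense(3, activation="softmax")(x)
    return tf.keras.Model(inputs, outputs)


def round_trip(path: str = "testmodel.h5"):
    """Save a model to HDF5 and reload it, both via tf.keras."""
    model = build_model()
    x = np.random.default_rng(0).random((5, 4), dtype=np.float32)
    before = model.predict(x, verbose=0)

    # Save and reload with tf.keras, never mixing in the standalone `keras` package.
    model.save(path)
    restored = tf.keras.models.load_model(path)
    after = restored.predict(x, verbose=0)
    return before, after


if __name__ == "__main__":
    before, after = round_trip()
    print("predictions match:", np.allclose(before, after))
```

For a real AutoKeras model the same call may also need the custom layers that AutoKeras registers. Later AutoKeras 1.0.x releases document obtaining the Keras model with `clf.export_model()` and loading it with `tf.keras.models.load_model(path, custom_objects=ak.CUSTOM_OBJECTS)`; check that both exist in the installed version before relying on them.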

For python - How to save/load models in AutoKeras 1.0, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/59533584/
