
python - tensorflow keras RandomForestModel get_config() is empty

Reposted. Author: 行者123. Updated: 2023-12-05 04:26:25

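I'd like to be able to view the hyperparameters passed to the Keras RandomForestModel. I assumed this would be possible with model.get_config(). However, after creating and training the model, get_config() always returns an empty dict.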

Here is the method in my RandomForestWrapper class that creates the model:

def add_new_model(self, model_name, params):
    # Assumes the module does: import tensorflow_decision_forests as tfdf
    self.train_test_split()

    model = tfdf.keras.RandomForestModel(
        random_seed=params["random_seed"],
        num_trees=params["num_trees"],
        categorical_algorithm=params["categorical_algorithm"],
        compute_oob_performances=params["compute_oob_performances"],
        growing_strategy=params["growing_strategy"],
        honest=params["honest"],
        max_depth=params["max_depth"],
        max_num_nodes=params["max_num_nodes"]
    )

    print(model.get_config())  # prints {}
    self.models.update({model_name: model})
    print(f"{model_name} added")

Example parameters:

params_v2 = {
    "random_seed": 123456,
    "num_trees": 1000,
    "categorical_algorithm": "CART",
    "compute_oob_performances": True,
    "growing_strategy": "LOCAL",
    "honest": True,
    "max_depth": 8,
    "max_num_nodes": None
}

Then I instantiate the class and train the model:

rf_models = RF(data, obs_col="obs", class_col="cell_type")
rf_models.add_new_model("model_2", params_v2)
rf_models.train_model("model_2", verbose=False, metrics=["Accuracy"])

model = rf_models.models["model_2"]
model.get_config()

Output:
{}

In the model summary, I can see that the parameters were accepted.

Best answer

Regarding get_config(), note what the docs state:

Returns the config of the Model.

Config is a Python dictionary (serializable) containing the configuration of an object, which in this case is a Model. This allows the Model to be reinstantiated later (without its trained weights) from this configuration.

Note that get_config() does not guarantee to return a fresh copy of dict every time it is called. The callers should make a copy of the returned dict if they want to modify it.

Developers of subclassed Model are advised to override this method, and continue to update the dict from super(MyModel, self).get_config() to provide the proper configuration of this Model. The default config is an empty dict. Optionally, raise NotImplementedError to allow Keras to attempt a default serialization.
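The contract described in that quote can be illustrated without TensorFlow: a base class whose default get_config() returns an empty dict (consistent with the {} seen in the question), and a subclass that extends super().get_config() with its constructor arguments so that from_config() can rebuild it. This is a plain-Python sketch that only mimics the documented behaviour; BaseModel and MyModel are hypothetical names, not Keras classes, and the from_config() shown here is a simplified assumption rather than Keras' actual implementation.

```python
class BaseModel:
    """Hypothetical stand-in for keras.Model; mimics only the documented get_config() default."""

    def get_config(self):
        # Documented default: an empty dict, which is what the question observes.
        return {}

    @classmethod
    def from_config(cls, config):
        # Simplified assumption: rebuild an untrained instance by passing the config to __init__.
        return cls(**config)


class MyModel(BaseModel):
    """Hypothetical subclass that follows the advice in the quoted docs."""

    def __init__(self, num_trees=300, max_depth=16):
        self.num_trees = num_trees
        self.max_depth = max_depth

    def get_config(self):
        # Copy the parent's dict (the docs warn it may not be fresh), then extend it.
        config = dict(super().get_config())
        config.update({"num_trees": self.num_trees, "max_depth": self.max_depth})
        return config


print(BaseModel().get_config())  # {}
model = MyModel(num_trees=1000, max_depth=8)
config = model.get_config()
print(config)  # {'num_trees': 1000, 'max_depth': 8}
clone = MyModel.from_config(config)
print(clone.num_trees, clone.max_depth)  # 1000 8
```

The round trip only works because the subclass puts every constructor argument into its config; a class that keeps the inherited empty dict cannot be rebuilt with its original settings this way.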

What I think you can do instead is read model.learner_params to get the details you want:

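import tensorflow_decision_forests as tfdf
import pprint

params_v2 = {
    "random_seed": 123456,
    "num_trees": 1000,
    "categorical_algorithm": "CART",
    "compute_oob_performances": True,
    "growing_strategy": "LOCAL",
    "honest": True,
    "max_depth": 8,
    "max_num_nodes": None
}

# from_config() is a classmethod, so it can be called on the class directly;
# there is no need to create a throwaway instance first.
model = tfdf.keras.RandomForestModel.from_config(params_v2)
# learner_params is a property (no parentheses). It lists every hyperparameter,
# including the defaults that were not set explicitly.
pprint.pprint(model.learner_params)

Output: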
{'adapt_bootstrap_size_ratio_for_maximum_training_duration': False,
'allow_na_conditions': False,
'bootstrap_size_ratio': 1.0,
'bootstrap_training_dataset': True,
'categorical_algorithm': 'CART',
'categorical_set_split_greedy_sampling': 0.1,
'categorical_set_split_max_num_items': -1,
'categorical_set_split_min_item_frequency': 1,
'compute_oob_performances': True,
'compute_oob_variable_importances': False,
'growing_strategy': 'LOCAL',
'honest': True,
'honest_fixed_separation': False,
'honest_ratio_leaf_examples': 0.5,
'in_split_min_examples_check': True,
'keep_non_leaf_label_distribution': True,
'max_depth': 8,
'max_num_nodes': None,
'maximum_model_size_in_memory_in_bytes': -1.0,
'maximum_training_duration_seconds': -1.0,
'min_examples': 5,
'missing_value_policy': 'GLOBAL_IMPUTATION',
'num_candidate_attributes': 0,
'num_candidate_attributes_ratio': -1.0,
'num_oob_variable_importances_permutations': 1,
'num_trees': 1000,
'pure_serving_model': False,
'random_seed': 123456,
'sampling_with_replacement': True,
'sorting_strategy': 'PRESORT',
'sparse_oblique_normalization': None,
'sparse_oblique_num_projections_exponent': None,
'sparse_oblique_projection_density_factor': None,
'sparse_oblique_weights': None,
'split_axis': 'AXIS_ALIGNED',
'uplift_min_examples_in_treatment': 5,
'uplift_split_score': 'KULLBACK_LEIBLER',
'winner_take_all': True}

A similar question, "python - tensorflow keras RandomForestModel get_config() is empty", can be found on Stack Overflow: https://stackoverflow.com/questions/73058773/
