gpt4 book ai didi

machine-learning - Kedro - 如何将嵌套参数直接传递给节点

转载 作者:行者123 更新时间:2023-12-04 10:08:32 28 4
gpt4 key购买 nike

kedro建议将参数存储在 conf/base/parameters.yml 中。让我们假设它看起来像这样:

step_size: 1
model_params:
learning_rate: 0.01
test_data_ratio: 0.2
num_train_steps: 10000

现在假设我有一些 data_engineering 管道,其 nodes.py 具有如下所示的函数:

def some_pipeline_step(num_train_steps):
"""
Takes the parameter `num_train_steps` as argument.
"""
pass

我如何将嵌套参数直接传递给 data_engineering/pipeline.py 中的此函数?我尝试失败:

from kedro.pipeline import Pipeline, node

from .nodes import split_data


def create_pipeline(**kwargs):
return Pipeline(
[
node(
some_pipeline_step,
["params:model_params.num_train_steps"],
dict(
train_x="train_x",
train_y="train_y",
),
)
]
)

我知道我可以使用 ['parameters'] 将所有参数传递到函数中,或者使用 ['params: 传递所有 model_params 参数model_params'] 但它看起来不优雅,我觉得一定有办法。如有任何意见,将不胜感激!

最佳答案

(免责声明:我是 Kedro 团队的一员)

谢谢你的提问。不幸的是,当前版本的 Kedro 不支持嵌套参数。临时解决方案是在节点内使用顶级键(正如您已经指出的那样)或使用某种参数过滤器来装饰节点函数,这也不优雅。

也许最可行的解决方案是自定义您的 ProjectContext (在 src/<package_name>/run.py 中)通过覆盖 _get_feed_dict 进行类方法如下:

class ProjectContext(KedroContext):
# ...


def _get_feed_dict(self) -> Dict[str, Any]:
"""Get parameters and return the feed dictionary."""
params = self.params
feed_dict = {"parameters": params}

def _add_param_to_feed_dict(param_name, param_value):
"""This recursively adds parameter paths to the `feed_dict`,
whenever `param_value` is a dictionary itself, so that users can
specify specific nested parameters in their node inputs.

Example:

>>> param_name = "a"
>>> param_value = {"b": 1}
>>> _add_param_to_feed_dict(param_name, param_value)
>>> assert feed_dict["params:a"] == {"b": 1}
>>> assert feed_dict["params:a.b"] == 1
"""
key = "params:{}".format(param_name)
feed_dict[key] = param_value

if isinstance(param_value, dict):
for key, val in param_value.items():
_add_param_to_feed_dict("{}.{}".format(param_name, key), val)

for param_name, param_value in params.items():
_add_param_to_feed_dict(param_name, param_value)

return feed_dict

另请注意,此问题已被 addressed on develop并将在下一个版本中提供。该修复使用上面代码片段中的方法。

关于machine-learning - Kedro - 如何将嵌套参数直接传递给节点,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61452211/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com