gpt4 book ai didi

python - 应如何将指标添加到多头 TensorFlow 估计器中?

转载 作者:行者123 更新时间:2023-11-30 09:58:52 24 4
gpt4 key购买 nike

我之前为 TensorFlow 分类器创建了指标,引用预测['logits']来计算指标。我已将模型从分类器更改为估计器,以实现多目标学习(使用 MultiHead)。然而,这导致 Python 抛出错误,因为现在预测的元素由头名称和原始键对进行键控,例如('label1','logits') 表示名称为“label1”的头。

我希望允许基于配置文件动态生成指标,以便更轻松地训练和测试具有不同标签组合的各种模型。现在的问题是,tf.estimator.add_metricsmetric_fn 参数不采用任何其他参数来允许动态确定或构建指标。

如何生成具有多个头和每个头的自定义指标的估算器?

最佳答案

围绕模型创建构建一个类来保存模型配置,并使用 metric_fn 参数的成员函数。

类模型构建器:
# 构造函数在 self 中存储配置选项
def __init__(自身、标签、other_config_args):
self.labels = 标签
...
# 用于构建多头估计器(多目标)的函数
def build_estimator(self, func_args):
头=[]
对于 self.labels 中的标签:
Heads.append(tf.estimator.MultiClassHead(n_classes=self.nclasses[label], name=label))
头 = tf.estimator.MultiHead(头)
estimator = tf.estimator.DNNEstimator(head=heads,...) # 或您想要的任何类型的估计器
估计器 = tf.estimator.add_metrics(估计器, self.model_metrics)
返回估计器
# 根据模型配置参数将指标添加到估计器的成员函数
def model_metrics(自身、标签、预测、特征):
指标 = {}
for label in self.labels: # 为每个头名称生成一个指标
指标['metric_name'] = metric_func(特征,标签,预测[(标签,'logits')])
返回指标

关于python - 应如何将指标添加到多头 TensorFlow 估计器中?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59888227/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com