gpt4 book ai didi

python - 如何删除这些 keras 指标定义的重复?

转载 作者:行者123 更新时间:2023-12-01 06:27:12 25 4
gpt4 key购买 nike

Keras 提供了准确度、精确度和召回率指标,可用于评估模型,但这些指标只能评估整个 y_truey_pred 。我希望它仅评估数据的子集。 y_true[..., 0:20]我的数据中包含我想要评估的二进制值,但是 y_true[..., 20:40]包含另一种数据。

因此,我修改了精确率和召回率类别,以仅对数据的前 20 个 channel 进行评估。我通过对这些指标进行子类化并要求他们在评估之前对数据进行切片来做到这一点。

from tensorflow import keras as kr

class SliceBinaryAccuracy(kr.metrics.BinaryAccuracy):
"""Slice data before evaluating accuracy. To be used as Keras metric"""

def __init__(self, channels, *args, **kwargs):
self.channels = channels
super().__init__(*args, **kwargs)

def _slice(self, y):
return y[..., : self.channels]

def __call__(self, y_true, y_pred, *args, **kwargs):
y_true = self._slice(y_true)
y_pred = self._slice(y_pred)
return super().__call__(y_true, y_pred, *args, **kwargs)

def update_state(self, y_true, y_pred, sample_weight=None):
y_true = self._slice(y_true)
y_pred = self._slice(y_pred)
super().update_state(y_true, y_pred, sample_weight=sample_weight)


class SlicePrecision(kr.metrics.Precision):
"""Slice data before evaluating precision. To be used as Keras metric"""

def __init__(self, channels, *args, **kwargs):
self.channels = channels
super().__init__(*args, **kwargs)

def _slice(self, y):
return y[..., : self.channels]

def __call__(self, y_true, y_pred, *args, **kwargs):
y_true = self._slice(y_true)
y_pred = self._slice(y_pred)
return super().__call__(y_true, y_pred, *args, **kwargs)

def update_state(self, y_true, y_pred, sample_weight=None):
y_true = self._slice(y_true)
y_pred = self._slice(y_pred)
super().update_state(y_true, y_pred, sample_weight=sample_weight)


class SliceRecall(kr.metrics.Recall):
"""Slice data before evaluating recall. To be used as Keras metric"""

def __init__(self, channels, *args, **kwargs):
self.channels = channels
super().__init__(*args, **kwargs)

def _slice(self, y):
return y[..., : self.channels]

def __call__(self, y_true, y_pred, *args, **kwargs):
y_true = self._slice(y_true)
y_pred = self._slice(y_pred)
return super().__call__(y_true, y_pred, *args, **kwargs)

def update_state(self, y_true, y_pred, sample_weight=None):
y_true = self._slice(y_true)
y_pred = self._slice(y_pred)
super().update_state(y_true, y_pred, sample_weight=sample_weight)

以上类的使用方式是这样的:

model.compile('adam', loss='mse', metrics=[SliceBinaryAccuracy(20), SlicePrecision(20), SliceRecall(20)])

代码可以运行,但我发现代码很长。我发现这 3 个指标有很多重复,我如何将这些类概括为单个类或其他更好的设计?如果可能,请给出示例代码。

最佳答案

我同意这些类中有太多重复,它们之间的唯一区别是它们子类化的指标。我认为这是应用某种工厂模式的一个很好的案例。我正在分享我创建的一个小函数,用于动态子类化指标。

def MetricFactory(cls, channels):
'''Takes a keras metric class and channels value and returns the instantiated subclassed metric'''

class DynamicMetric(cls):
def __init__(self, channels, *args, **kwargs):
self.channels = channels
super().__init__(*args, **kwargs)

def _slice(self, y):
return y[..., : self.channels]

def __call__(self, y_true, y_pred, *args, **kwargs):
y_true = self._slice(y_true)
y_pred = self._slice(y_pred)
return super().__call__(y_true, y_pred, *args, **kwargs)

def update_state(self, y_true, y_pred, sample_weight=None):
y_true = self._slice(y_true)
y_pred = self._slice(y_pred)
super().update_state(y_true, y_pred, sample_weight=sample_weight)

x = DynamicMetric(channels)
return x

然后您可以按如下方式使用它:

metrics = [MetricFactory(kr.metrics.BinaryAccuracy, 20), MetricFactory(kr.metrics.Precision, 20), MetricFactory(kr.metrics.Recall, 20)]
model.compile('adam', loss='mse', metrics=metrics)

由于覆盖的方法对于您要子类化的三个指标完全相同,因此该函数可以将它们直接注入(inject)到新类中。为了简单起见,该函数返回实例化的子类,但您也可以返回新类。值得注意的是,如果您必须将要覆盖的方法作为参数传递,并且可能需要在此 thread 行中使用元类或奇妙的黑魔法,则这种特殊方法将不起作用。 。

关于python - 如何删除这些 keras 指标定义的重复?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60076928/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com