gpt4 book ai didi

tensorflow - 为什么在 Keras 指标函数中使用 axis=-1 ?

转载 作者:行者123 更新时间:2023-12-03 10:02:47 26 4
gpt4 key购买 nike

keras 版本:2.0.8

在一些 Keras 度量函数和损失函数中,使用 axis=-1 作为参数。

例如:

def binary_accuracy(y_true, y_pred):
return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)

就我而言:

y_true 的形状:(4,256,256,2)

y_pred 的形状:(4,256,256,2)

因此, binary_accuracy(y_true, y_pred) 应该返回一个 shape=(4,256,256) 而不是标量张量的张量。

但是当使用 binary_accuracy 作为度量函数时:
model.compile(optimizer=adam, loss=keras.losses.binary_crossentropy, metrics=[binary_accuracy])

日志仍然将 binary_accuracy 打印为标量,这让我很困惑。

keras 是否对 binary_accuracy 函数的返回做了一些特殊处理?

Epoch 11/300

0s - loss: 0.4158 - binary_accuracy: 0.9308 - val_loss: 0.4671 - val_binary_accuracy: 0.7767

最佳答案

这就是您要找的东西,里面 training_utils.py :

def weighted(y_true, y_pred, weights, mask=None):
"""Wrapper function.
# Arguments
y_true: `y_true` argument of `fn`.
y_pred: `y_pred` argument of `fn`.
weights: Weights tensor.
mask: Mask tensor.
# Returns
Scalar tensor.
"""
# score_array has ndim >= 2
score_array = fn(y_true, y_pred)
if mask is not None:
# Cast the mask to floatX to avoid float64 upcasting in Theano
mask = K.cast(mask, K.floatx())
# mask should have the same shape as score_array
score_array *= mask
# the loss per batch should be proportional
# to the number of unmasked samples.
score_array /= K.mean(mask) + K.epsilon()

# apply sample weighting
if weights is not None:
# reduce score_array to same ndim as weight array
ndim = K.ndim(score_array)
weight_ndim = K.ndim(weights)
score_array = K.mean(score_array,
axis=list(range(weight_ndim, ndim)))
score_array *= weights
score_array /= K.mean(K.cast(K.not_equal(weights, 0), K.floatx()))
return K.mean(score_array)
return weighted

度量函数由 score_array = fn(y_true, y_pred) 调用(它是一个嵌套函数, fn 在外部函数中定义)。该数组在最后一行中取平均值 return K.mean(score_array) .这就是为什么您看到的是标量指标而不是张量。中间的线只是在必要时引入掩码和权重。

关于tensorflow - 为什么在 Keras 指标函数中使用 axis=-1 ?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46298110/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com