gpt4 book ai didi

python - 如何使用分类焦点损失使用 keras 进行一次热编码?

转载 作者:行者123 更新时间:2023-12-04 03:41:58 24 4
gpt4 key购买 nike

我正在研究癫痫发作预测。我有不平衡的数据集,我想通过使用焦点损失使其平衡。我有 2 个类的单热编码向量。我找到了下面的焦点损失代码,但我不知道如何获得 y_predmodel.fit_generator 之前用于焦点损失代码.y_pred是模型的输出。那么如何在拟合模型之前在焦点损失代码中使用它?
焦点损失代码:

def categorical_focal_loss(gamma=2.0, alpha=0.25):
"""
Implementation of Focal Loss from the paper in multiclass classification
Formula:
loss = -alpha*((1-p)^gamma)*log(p)
Parameters:
alpha -- the same as wighting factor in balanced cross entropy
gamma -- focusing parameter for modulating factor (1-p)
Default value:
gamma -- 2.0 as mentioned in the paper
alpha -- 0.25 as mentioned in the paper
"""
def focal_loss(y_true, y_pred):
# Define epsilon so that the backpropagation will not result in NaN
# for 0 divisor case
epsilon = K.epsilon()
# Add the epsilon to prediction value
#y_pred = y_pred + epsilon
# Clip the prediction value
y_pred = K.clip(y_pred, epsilon, 1.0-epsilon)
# Calculate cross entropy
cross_entropy = -y_true*K.log(y_pred)
# Calculate weight that consists of modulating factor and weighting factor
weight = alpha * y_true * K.pow((1-y_pred), gamma)
# Calculate focal loss
loss = weight * cross_entropy
# Sum the losses in mini_batch
loss = K.sum(loss, axis=1)
return loss

return focal_loss
我的代码:
history=model.fit_generator(generate_arrays_for_training(indexPat, train_data, start=0,end=100)
validation_data=generate_arrays_for_training(indexPat, test_data, start=0,end=100)
steps_per_epoch=int((len(train_data)/2)),
validation_steps=int((len(test_data)/2)),
verbose=2,epochs=65, max_queue_size=2, shuffle=True)
preictPrediction=model.predict_generator(generate_arrays_for_predict(indexPat, filesPath_data), max_queue_size=4, steps=len(filesPath_data))
y_pred1=np.argmax(preictPrediction,axis=1)
y_pred=list(y_pred1)

最佳答案

为了社区的利益,来自评论部分。

This is not specific to focal loss, all keras loss functions takey_true and y_pred, you do not need to worry where those parameters arecoming from, they are fed by keras automatically.

关于python - 如何使用分类焦点损失使用 keras 进行一次热编码?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65875860/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com