gpt4 book ai didi

keras - 如何将权重传递给keras中的均方误差

转载 作者:行者123 更新时间:2023-12-03 01:07:10 25 4
gpt4 key购买 nike

我正在尝试解决一个回归问题,这是一个带有 8 个标签的多标签问题,我使用均方误差损失,但数据集不平衡,我想将权重传递给损失函数。目前我正在编译模型是这样的。

model.compile(loss='mse', optimizer=Adam(lr=0.0001), metrics=['mse', 'acc'])

有人可以建议是否可以将权重添加到均方误差上,如果可以,我该怎么做?

提前致谢

标签看起来像这样

enter image description here

#
model = Sequential()    
model.add(effnet)
model.add(GlobalAveragePooling2D())
model.add(Dropout(0.5))
model.add(Dense(8,name = 'nelu', activation=elu))
model.compile(loss=custom_mse(class_weights),
optimizer=Adam(lr=0.0001), metrics=['mse', 'acc'])

最佳答案

import keras
from keras.models import Sequential
from keras.layers import Conv2D, Flatten, Dense, Conv1D, LSTM, TimeDistributed
import keras.backend as K


# custom loss function
def custom_mse(class_weights):
def loss_fixed(y_true, y_pred):
"""
:param y_true: A tensor of the same shape as `y_pred`
:param y_pred: A tensor resulting from a sigmoid
:return: Output tensor.
"""
# print('y_pred:', K.int_shape(y_pred))
# print('y_true:', K.int_shape(y_true))
y_pred = K.reshape(y_pred, (8, 1))
y_pred = K.dot(class_weights, y_pred)
# calculating mean squared error
mse = K.mean(K.square(y_pred - y_true), axis=-1)
# print('mse:', K.int_shape(mse))
return mse

model = Sequential()
model.add(Conv1D(8, (1), input_shape=(28, 28)))
model.add(Flatten())
model.add(Dense(8))

# custom class weights
class_weights = K.variable([[0.25, 1., 2., 3., 2., 0.6, 0.5, 0.15]])
# print('class_weights:', K.int_shape(class_weights))

model.compile(optimizer='adam', loss=custom_mse(class_weights), metrics=['accuracy'])

这是一个基于您的问题陈述的自定义损失函数的小型实现

  • 您可以从 losses.py 找到有关 keras 损失函数的更多信息并查看其官方文档 here

  • Keras 本身不处理低级运算,例如张量积、卷积等。相反,它依赖于一个专门的、优化良好的张量操作库来完成此操作,充当 Keras 的“后端引擎”。有关 keras backend 的更多信息可以在这里找到,也可以从here查看其官方文档。

  • 使用 K.int_shape(tensor_name) 查找 dimensions of a tensor .

关于keras - 如何将权重传递给keras中的均方误差,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57840750/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com