gpt4 book ai didi

tensorflow - 批量数据是否应该平衡?

转载 作者:行者123 更新时间:2023-11-30 08:37:50 24 4
gpt4 key购买 nike

我正在训练一个深度学习模型,通过输入推文内容来预测三种情绪(快乐、悲伤、愤怒)。

我遇到的一个问题是我的模型可以很好地学习悲伤、欢乐,但对于欢乐却很糟糕。

Confusion matrix in three emotion

我认为原因是我的训练数据集不平衡。

快乐时的数据量:196952,悲伤:29407,愤怒:42420

因此,在训练模型时,批量大小包含太多的喜悦数据集,这使得模型只能猜测答案是喜悦而不是其他答案。

我想通过平衡每批中的数据来解决这个问题。也就是说批量大小为 128,我们随机选择相同数量的三个情感数据。防止模型被快乐的数据所主导。

Question is: Should the data in batch be balanced?

另一个问题是,我随机选择的数据集,这是否违反了纪元的定义。

因为纪元意味着读取所有训练数据集。当随机选择时,也许某些数据集在某些时期不会被选择。或者只是训练更多的纪元就能解决这个问题?

如果我觉得有什么不对的地方,欢迎指出。谢谢!

最佳答案

一种可能的方法是为分类器添加权重。

来自: https://www.tensorflow.org/tutorials/structured_data/imbalanced_data#class_weights

The goal is to identify fraudulent transactions, but you don't havevery many of those positive samples to work with, so you would want tohave the classifier heavily weight the few examples that areavailable. You can do this by passing Keras weights for each classthrough a parameter. These will cause the model to "pay moreattention" to examples from an under-represented class.

由于您的问题是多类的,您可以这样做 https://scikit-learn.org/stable/modules/generated/sklearn.utils.class_weight.compute_class_weight.html

我用类似的东西来做到这一点:

from sklearn.utils import class_weight

class_weights = dict (enumerate (class_weight.compute_class_weight (
class_weight = 'balanced',
classes = available_labels,
y = self.dataset.get_split (df, 'train')['label']
)))

然后:

history = model.fit (
...
class_weight = class_weights
)

根据我的经验,这种方法可以实现更好的解决方案,同时使训练速度更快。

此外,我认为保持大批量并确保数据是随机的也是处理不平衡数据的其他好方法。

关于tensorflow - 批量数据是否应该平衡?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48200136/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com