python - What is the correct procedure to split a dataset for a classification problem?

Reposted. Author: 行者123. Updated: 2023-11-30 09:28:00

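I am new to machine learning and deep learning. Before training, I would like to clear up my doubts about train_test_split.

I have a dataset of shape (302, 100, 5), where

(207, 100, 5) belongs to class 0

(95, 100, 5) belongs to class 1.

I want to perform classification with an LSTM (because the data is sequential).

How should I split my dataset for training, given that the classes are not equally distributed?

Option 1: Take the whole dataset [(302, 100, 5) - both classes (0 & 1)], shuffle it, apply train_test_split, and proceed to training.

Option 2: Take an equal number of samples from each class [(95, 100, 5) - class 0 & (95, 100, 5) - class 1], shuffle them, apply train_test_split, and proceed to training.

Which is the better way to split before training, so that I get better results in terms of loss reduction, accuracy, and prediction?

If there are options other than the two above, please recommend them.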

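Following the comments section, I am including part of the data:

X_train: shape (241 * 100 * 5)

Each row within a 100*5 block corresponds to one time step, so the 100 rows correspond to 100 time steps (in milliseconds (ms))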

array([[[0.98620635, 0.        , 0.12752912, 0.60897341, 0.46903766],
[0.97345112, 0. , 0.12752912, 0.49205995, 0.38709902],
[0.9566397 , 0. , 0.12752912, 0.45728718, 0.42154812],
...,
[0.28669754, 0.8852459 , 0.12752912, 0.8786213 , 0.80125523],
[0.31559784, 0.8852459 , 0.20968731, 0.89087803, 0.79476987],
[0.34368841, 0.8852459 , 0.12752912, 0.89087803, 0.71066946]],

[[0.97957188, 0.14909194, 0.04159147, 0.50548561, 0.34209531],
[0.9687237 , 0.13964397, 0.04159147, 0.55926067, 0.64613533],
[0.96596236, 0.13553813, 0.04159147, 0.55903796, 0.85299319],
...,
[0.49309139, 0.72396527, 0.04159147, 0.81998825, 0.12362443],
[0.52072591, 0.70872926, 0.04159147, 0.82361951, 0.89639432],
[0.54441507, 0.71835207, 0.04159147, 0.84964602, 1. ]],

[[0.48151381, 0.875 , 0.16666667, 0.90637286, 0.62737926],
[0.53325374, 0.8625 , 0.33333333, 0.87881677, 0.5321154 ],
[0.57506452, 0.81859091, 0.16666667, 0.84915758, 0.3552661 ],
...,
[0.34456041, 0.92993213, 0.33333333, 0.92953899, 0.78782408],
[0.39496018, 0.90523485, 0.33333333, 0.9117954 , 0.54579383],
[0.44187985, 0.8625 , 0.33333333, 0.84163194, 0.25789356]],

...,

[[0.16368355, 0. , 0.15313225, 0.40101906, 0.36784741],
[0.15679684, 0. , 0.15313225, 0.4435126 , 0.67351994],
[0.15544309, 0.06132052, 0.15313225, 0.40101906, 0.36611345],
...,
[0.43936628, 0.68292683, 0.15313225, 0.82305329, 0.36784741],
[0.49751546, 0.68292683, 0.07764888, 0.84141109, 0.42828833],
[0.53288488, 0.68292683, 0.15313225, 0.85959823, 0.36784741]],

[[0.9418247 , 0.30821318, 0.03072816, 0.744977 , 0.93769733],
[0.9537216 , 0.28989357, 0.03072816, 0.74576381, 0.98468743],
[0.96455286, 0.21736423, 0.03072816, 0.74182977, 1. ],
...,
[0.36273884, 0.60113245, 0.06145633, 0.85409181, 0.32277415],
[0.38774614, 0.57789971, 0.05844559, 0.82937631, 0. ],
[0.41546859, 0.57789971, 0.03072816, 0.79315883, 0.31256578]],

[[0.97868688, 0.06451613, 0.00411829, 0.64705259, 0.69827586],
[0.97999663, 0.06451613, 0.02256676, 0.66812232, 0.75195925],
[0.97143037, 0.02476377, 0.02256676, 0.66317859, 0.78487461],
...,
[0.50336862, 0.73867709, 0.02256676, 0.84921606, 0.1226489 ],
[0.54003486, 0.72043011, 0.02256676, 0.82679269, 0.20297806],
[0.57594039, 0.70967742, 0.02256676, 0.83350205, 0. ]]])

Y_train: shape (241,)

[1. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 0. 0. 1. 1. 0. 0. 0.
1. 1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 1. 1.
0. 0. 1. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 1. 1. 0. 0. 1. 0. 1. 0. 1. 0.
0. 1. 0. 0. 1. 1. 0. 0. 0. 1. 1. 0. 0. 1. 0. 0. 1. 0. 0. 1. 0. 0. 0. 1.
1. 0. 0. 1. 0. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 1.
0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 1. 0. 1. 1. 0. 0. 0. 0. 0. 1. 0.
0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 1. 0. 0. 0. 1. 1. 0. 0. 1. 1. 1. 0. 1.
0. 1. 0. 1. 0. 0. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.
1. 0. 0. 1. 1. 1. 0. 1. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1.
1.]

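For reference: as you can see above, the X_train data is large, so I cannot include the complete X_train here. I am therefore providing only one snippet, to give a better idea of what a single sample (i.e. X_train[0], shape (100*5)) looks like. The remaining 240 look more or less like the one below.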

array([[9.86206354e-01, 0.00000000e+00, 1.27529123e-01, 2.29139335e-02,
6.08973407e-01, 4.69037657e-01],
[9.73451120e-01, 0.00000000e+00, 1.27529123e-01, 2.60807671e-02,
4.92059955e-01, 3.87099024e-01],
[9.56639704e-01, 0.00000000e+00, 1.27529123e-01, 2.64184174e-02,
4.57287179e-01, 4.21548117e-01],
[9.34897700e-01, 0.00000000e+00, 1.27529123e-01, 2.64184174e-02,
4.84177685e-01, 4.69037657e-01],
[9.18030989e-01, 0.00000000e+00, 1.27529123e-01, 2.64184174e-02,
4.86406180e-01, 4.08577406e-01],
[9.02168015e-01, 0.00000000e+00, 1.27529123e-01, 2.64020795e-02,
4.84920517e-01, 4.04184100e-01],
[8.82551572e-01, 0.00000000e+00, 1.27529123e-01, 2.56783096e-02,
4.51195959e-01, 3.78661088e-01],
[8.69975342e-01, 0.00000000e+00, 1.27529123e-01, 2.40477851e-02,
4.70286733e-01, 4.23640167e-01],
[8.41027241e-01, 0.00000000e+00, 1.27529123e-01, 1.75387576e-02,
5.04754123e-01, 4.34728033e-01],
[8.28189535e-01, 5.28763040e-01, 1.27529123e-01, 6.89133486e-03,
4.98662903e-01, 4.58368201e-01],
[8.21784739e-01, 8.21162444e-01, 1.27529123e-01, 1.06196483e-02,
5.87431288e-01, 5.72594142e-01],
[8.26651597e-01, 9.96721311e-01, 1.27529123e-01, 1.75044480e-02,
6.89050661e-01, 5.40376569e-01],
[8.42115326e-01, 1.00000000e+00, 1.27529123e-01, 1.71205069e-02,
8.35388501e-01, 4.69037657e-01],
[8.64071009e-01, 9.26875310e-01, 1.27529123e-01, 1.34068975e-02,
1.00000000e+00, 4.65062762e-01],
[8.79579724e-01, 7.60158967e-01, 1.27529123e-01, 4.65303975e-03,
9.61744169e-01, 3.65481172e-01],
[9.03630040e-01, 7.61549925e-01, 1.27529123e-01, 4.21518348e-03,
9.22076957e-01, 3.78033473e-01],
[9.18435858e-01, 6.72429210e-01, 1.27529123e-01, 2.70229205e-03,
9.39979201e-01, 5.03138075e-01],
[9.29983046e-01, 6.85345256e-01, 1.27529123e-01, 9.05120794e-04,
8.53736443e-01, 5.52510460e-01],
[9.48081232e-01, 5.78539493e-01, 1.27529123e-01, 6.96485550e-03,
8.84415391e-01, 3.04602510e-01],
[9.48112160e-01, 5.55091903e-01, 1.27529123e-01, 1.10493356e-02,
8.19046204e-01, 4.78661088e-01],
[9.61281634e-01, 5.08693492e-01, 1.27529123e-01, 9.36162843e-03,
8.23651761e-01, 3.21548117e-01],
[9.72179346e-01, 4.91803279e-01, 1.27529123e-01, 9.82725917e-03,
7.57391175e-01, 4.96025105e-01],
[9.84752763e-01, 4.91803279e-01, 1.27529123e-01, 7.04491131e-03,
7.59322538e-01, 3.95397490e-01],
[9.90300024e-01, 4.91803279e-01, 1.27529123e-01, 8.19346712e-03,
7.64819492e-01, 4.69037657e-01],
[9.88306609e-01, 3.77049180e-01, 1.27529123e-01, 8.62642201e-03,
7.93492795e-01, 4.16945607e-01],
[9.91084457e-01, 3.93442623e-01, 1.27529123e-01, 9.16557339e-03,
7.10741346e-01, 4.72175732e-01],
[1.00000000e+00, 3.78936910e-01, 1.27529123e-01, 1.16538387e-02,
6.93359085e-01, 4.76987448e-01],
[9.98925974e-01, 3.93442623e-01, 1.27529123e-01, 1.21309060e-02,
7.16609716e-01, 3.46025105e-01],
[9.92838888e-01, 3.32141083e-01, 1.27529123e-01, 1.19315833e-02,
7.31540633e-01, 4.16527197e-01],
[9.90637415e-01, 3.36910084e-01, 1.27529123e-01, 9.95632874e-03,
7.12524142e-01, 4.15481172e-01],
[9.90761125e-01, 3.38301043e-01, 1.27529123e-01, 6.59235091e-03,
6.86970732e-01, 4.37656904e-01],
[9.90274720e-01, 3.27868852e-01, 2.10913550e-01, 5.68396253e-03,
7.09181399e-01, 4.99372385e-01],
[9.83015202e-01, 3.27868852e-01, 1.27529123e-01, 2.14974358e-02,
7.31392067e-01, 6.41631799e-01],
[9.77392028e-01, 2.85245902e-01, 1.47762109e-01, 2.52861995e-02,
7.09478532e-01, 6.07112971e-01],
[9.75300207e-01, 2.78688525e-01, 1.27529123e-01, 2.91468501e-02,
6.70257020e-01, 6.28242678e-01],
[9.74917831e-01, 2.71733731e-01, 1.27529123e-01, 3.58780734e-02,
6.70257020e-01, 5.72594142e-01],
[9.64950755e-01, 2.62295082e-01, 1.27529123e-01, 3.92992339e-02,
6.36383895e-01, 6.67991632e-01],
[9.63159774e-01, 2.62295082e-01, 1.27529123e-01, 4.82932591e-02,
6.93581934e-01, 5.46443515e-01],
[9.54983679e-01, 2.90511674e-01, 1.27529123e-01, 4.90627752e-02,
6.59708810e-01, 7.40376569e-01],
[9.57595643e-01, 3.11475410e-01, 1.27529123e-01, 4.72492660e-02,
6.49977715e-01, 5.61297071e-01],
[9.51511369e-01, 2.95081967e-01, 1.27529123e-01, 1.82576261e-02,
6.64314366e-01, 5.22384937e-01],
[9.48528275e-01, 2.95081967e-01, 1.27529123e-01, 3.89659403e-03,
6.29846977e-01, 3.20711297e-01],
[9.47085931e-01, 2.95081967e-01, 1.27529123e-01, 6.86682798e-03,
6.48417769e-01, 4.38284519e-01],
[9.38153518e-01, 2.95081967e-01, 1.27529123e-01, 5.73951146e-03,
7.04130144e-01, 5.32635983e-01],
[9.38114156e-01, 2.95081967e-01, 1.27529123e-01, 2.05955826e-02,
6.85782202e-01, 5.47280335e-01],
[9.35597786e-01, 2.95081967e-01, 1.27529123e-01, 2.91141743e-02,
6.69142772e-01, 7.13807531e-01],
[9.29311077e-01, 2.72826627e-01, 1.27529123e-01, 2.91141743e-02,
6.81622344e-01, 5.72594142e-01],
[9.25495753e-01, 2.23646299e-01, 1.27529123e-01, 2.65507546e-02,
6.35566781e-01, 6.41004184e-01],
[9.18525829e-01, 2.08643815e-03, 1.27529123e-01, 2.37618715e-02,
6.09641955e-01, 5.02928870e-01],
[8.91801693e-01, 0.00000000e+00, 1.27529123e-01, 9.27013608e-03,
5.26073392e-01, 4.21338912e-01],
[8.77693149e-01, 0.00000000e+00, 1.27529123e-01, 8.13628440e-03,
4.22522656e-01, 3.44560669e-01],
[8.61894841e-01, 0.00000000e+00, 1.27529123e-01, 1.49639014e-02,
4.52755906e-01, 3.65481172e-01],
[8.44254943e-01, 0.00000000e+00, 1.27529123e-01, 2.29515107e-02,
4.59069975e-01, 3.76150628e-01],
[8.21183060e-01, 0.00000000e+00, 1.27529123e-01, 3.97583295e-02,
4.60852771e-01, 2.60460251e-01],
[8.04116726e-01, 0.00000000e+00, 1.27529123e-01, 5.89292454e-02,
4.26905363e-01, 1.97907950e-01],
[7.81311943e-01, 0.00000000e+00, 1.27529123e-01, 8.53656345e-02,
4.37379290e-01, 1.00836820e-01],
[7.60863270e-01, 0.00000000e+00, 1.27529123e-01, 1.03087377e-01,
4.37379290e-01, 6.98744770e-02],
[7.41227145e-01, 0.00000000e+00, 1.27529123e-01, 1.14206966e-01,
4.27128213e-01, 1.58368201e-01],
[7.26694052e-01, 0.00000000e+00, 1.27529123e-01, 1.17776801e-01,
4.37379290e-01, 0.00000000e+00],
[7.08716764e-01, 0.00000000e+00, 1.27529123e-01, 1.17288297e-01,
4.48596048e-01, 2.18619247e-01],
[6.90483621e-01, 0.00000000e+00, 1.27529123e-01, 1.08491961e-01,
4.58549993e-01, 1.26987448e-01],
[6.67451099e-01, 0.00000000e+00, 1.27529123e-01, 8.38217010e-02,
4.99628584e-01, 3.55020921e-01],
[6.51610618e-01, 0.00000000e+00, 1.27529123e-01, 4.32889541e-02,
5.10919626e-01, 4.83054393e-01],
[6.31195684e-01, 0.00000000e+00, 1.27529123e-01, 1.29200275e-02,
5.21170703e-01, 4.97907950e-01],
[6.14317726e-01, 0.00000000e+00, 2.26241570e-01, 9.32895259e-04,
4.98960036e-01, 4.69037657e-01],
[5.98165158e-01, 0.00000000e+00, 5.90435316e-01, 0.00000000e+00,
4.61892735e-01, 5.03556485e-01],
[5.68221755e-01, 0.00000000e+00, 6.33353771e-01, 1.61745413e-03,
4.25122567e-01, 4.69037657e-01],
[5.35292447e-01, 0.00000000e+00, 1.00000000e+00, 8.99402522e-03,
3.58490566e-01, 5.10041841e-01],
[5.10766973e-01, 0.00000000e+00, 3.93010423e-01, 3.39894098e-02,
3.27068786e-01, 6.15690377e-01],
[4.78939807e-01, 0.00000000e+00, 5.32188841e-01, 5.98114931e-02,
3.27068786e-01, 6.22175732e-01],
[4.47053597e-01, 0.00000000e+00, 4.31023912e-01, 8.44245703e-02,
3.24023176e-01, 6.76150628e-01],
[4.13654754e-01, 0.00000000e+00, 5.32188841e-01, 1.07209434e-01,
2.90298618e-01, 7.08577406e-01],
[3.80151882e-01, 0.00000000e+00, 7.97057020e-01, 1.21122807e-01,
1.19150201e-01, 4.95397490e-01],
[3.28235926e-01, 0.00000000e+00, 3.56223176e-01, 1.23820198e-01,
0.00000000e+00, 6.65271967e-01],
[2.83452966e-01, 0.00000000e+00, 2.28694053e-01, 1.22658572e-01,
2.65933739e-02, 5.55648536e-01],
[2.38616587e-01, 0.00000000e+00, 2.28694053e-01, 1.22990232e-01,
9.41910563e-02, 4.92887029e-01],
[1.82964031e-01, 0.00000000e+00, 5.19926426e-01, 1.30564491e-01,
8.97340663e-02, 4.94142259e-01],
[1.43835174e-01, 0.00000000e+00, 5.25444513e-01, 1.64135650e-01,
1.14618927e-01, 7.40585774e-01],
[1.04402664e-01, 0.00000000e+00, 1.55119559e-01, 2.41378071e-01,
1.98261774e-01, 6.50418410e-01],
[7.96438281e-02, 0.00000000e+00, 7.11220110e-02, 3.27145618e-01,
2.89110088e-01, 7.45188285e-01],
[6.36065353e-02, 0.00000000e+00, 0.00000000e+00, 4.11129065e-01,
4.05140395e-01, 6.88912134e-01],
[4.11672585e-02, 0.00000000e+00, 2.52605763e-01, 5.62182942e-01,
4.54315852e-01, 1.00000000e+00],
[2.87063044e-02, 0.00000000e+00, 1.27529123e-01, 6.81786323e-01,
4.59515674e-01, 9.32217573e-01],
[1.70269716e-02, 1.58966716e-03, 1.27529123e-01, 7.33474602e-01,
4.37453573e-01, 6.07322176e-01],
[3.30361486e-03, 6.37853949e-01, 1.27529123e-01, 8.06276376e-01,
4.69692468e-01, 7.54602510e-01],
[0.00000000e+00, 7.89369101e-01, 1.27529123e-01, 8.85843682e-01,
5.10919626e-01, 8.70502092e-01],
[5.13114648e-03, 8.19672131e-01, 1.27529123e-01, 9.60932765e-01,
5.99316595e-01, 8.79288703e-01],
[2.16829598e-02, 8.36065574e-01, 1.27529123e-01, 9.99121020e-01,
7.28866439e-01, 8.56903766e-01],
[4.27951674e-02, 8.36065574e-01, 1.27529123e-01, 1.00000000e+00,
8.67181697e-01, 7.88912134e-01],
[7.02334461e-02, 8.36065574e-01, 1.27529123e-01, 9.93500775e-01,
8.46308127e-01, 9.78451883e-01],
[9.73680733e-02, 8.36065574e-01, 1.27529123e-01, 9.87896869e-01,
8.66364582e-01, 8.59414226e-01],
[1.23611427e-01, 8.36065574e-01, 1.27529123e-01, 9.69613102e-01,
8.35685634e-01, 9.17991632e-01],
[1.52157471e-01, 8.68852459e-01, 1.27529123e-01, 9.22226597e-01,
7.96686971e-01, 9.65062762e-01],
[1.77979087e-01, 8.68852459e-01, 1.27529123e-01, 8.61132577e-01,
8.29594414e-01, 8.14225941e-01],
[2.03010647e-01, 8.84252360e-01, 1.27529123e-01, 8.13277174e-01,
8.29594414e-01, 9.11506276e-01],
[2.32490138e-01, 8.85245902e-01, 1.27529123e-01, 7.59549923e-01,
8.41851137e-01, 9.52301255e-01],
[2.58952796e-01, 8.85245902e-01, 1.27529123e-01, 6.97804020e-01,
8.55667806e-01, 8.68200837e-01],
[2.86697538e-01, 8.85245902e-01, 1.27529123e-01, 6.25149288e-01,
8.78621304e-01, 8.01255230e-01],
[3.15597842e-01, 8.85245902e-01, 2.09687308e-01, 5.51940700e-01,
8.90878027e-01, 7.94769874e-01],
[3.43688409e-01, 8.85245902e-01, 1.27529123e-01, 4.75801089e-01,
8.90878027e-01, 7.10669456e-01]])

Best answer

TLDR: Try both!
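To make "both" concrete, here is a minimal sketch of how the two options from the question could be built with NumPy. The arrays are random placeholders with the shapes given in the question (not the asker's real data), and the variable names are only illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

# Placeholder data with the shapes from the question (random values, NOT the real data).
X = rng.random((302, 100, 5))
y = np.concatenate([np.zeros(207, dtype=int), np.ones(95, dtype=int)])

# Option 1: keep every sample and shuffle samples and labels together.
perm = rng.permutation(len(y))
X_all, y_all = X[perm], y[perm]

# Option 2: randomly keep 95 of the 207 class-0 samples so both classes have 95.
idx0 = np.flatnonzero(y == 0)
idx1 = np.flatnonzero(y == 1)
keep0 = rng.choice(idx0, size=len(idx1), replace=False)
balanced = rng.permutation(np.concatenate([keep0, idx1]))
X_bal, y_bal = X[balanced], y[balanced]

print(X_all.shape, np.bincount(y_all))  # (302, 100, 5) [207  95]
print(X_bal.shape, np.bincount(y_bal))  # (190, 100, 5) [95 95]
```

Option 2 discards 112 class-0 sequences, which is a noticeable cost with only 302 samples, so it is worth comparing both variants on the same held-out data.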

I have been in a similar situation before, where my dataset was imbalanced. I used to get by with train_test_split or KFold.
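As a rough illustration of that approach, the sketch below uses scikit-learn's train_test_split and the stratified variant of KFold (StratifiedKFold) so that each split keeps roughly the original class ratio. The stratification, test_size=0.2, the 5 folds, and the random placeholder data are assumptions for illustration, not something stated in the answer.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold, train_test_split

rng = np.random.default_rng(0)

# Placeholder data with the shapes from the question (random values, NOT the real data).
X = rng.random((302, 100, 5))
y = np.concatenate([np.zeros(207, dtype=int), np.ones(95, dtype=int)])

# stratify=y keeps the class-0 / class-1 proportion about the same in both parts.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, shuffle=True, random_state=42
)
print(X_train.shape, X_test.shape)
print("class-1 share:", y.mean(), y_train.mean(), y_test.mean())

# Stratified K-fold on the training part; each fold keeps a similar class ratio.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
fold_ratios = []
# split() only needs the number of samples from its first argument,
# so a 1-D placeholder is passed instead of the 3-D array.
for fold, (tr_idx, va_idx) in enumerate(skf.split(np.zeros(len(y_train)), y_train)):
    X_tr, X_va = X_train[tr_idx], X_train[va_idx]
    y_tr, y_va = y_train[tr_idx], y_train[va_idx]
    fold_ratios.append(y_va.mean())
    print(f"fold {fold}: train={len(tr_idx)} val={len(va_idx)} "
          f"class-1 share in val={y_va.mean():.2f}")
```

With 302 samples and test_size=0.2 this should give a 241/61 split, which matches the X_train shape shown in the question.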

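Later, however, I came across the topic of handling imbalanced datasets and the techniques of over-sampling and under-sampling. For that, I would recommend the library imblearn.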

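There you will find various techniques for handling the case where one class outnumbers the other. I have personally used SMOTE and had relatively good success with it in such cases.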

Other references:

https://www.analyticsvidhya.com/blog/2017/03/imbalanced-classification-problem/

https://towardsdatascience.com/handling-imbalanced-datasets-in-machine-learning-7a0e84220f28

Regarding "python - What is the correct procedure to split a dataset for a classification problem?", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/57142772/
