gpt4 book ai didi

python - sklearn GMM 分类器的分类输出错误

转载 作者:行者123 更新时间:2023-11-30 08:55:39 26 4
gpt4 key购买 nike

我正在使用 sklearnGMM 工具包构建一个基本的说话人识别器。我有 3 个类(class),每个类(class)都有一个分类器。在测试阶段,应选择概率最高的说话者的 GMM,并且程序应返回每个测试样本的预测类别。我想改变混合组件的数量并在此示例代码中设置n_components=4。如果我使用 4 个混合分量,分类器的输出将是 0、1、2 或 3。如果我使用 3 个混合分量,它将是 0、1 或 2。我感觉分类器返回预测的混合分量而不是整个 GMM。但我希望它能够预测类别:1、2 或 3。

这是我的代码:

import numpy as np
from sklearn.mixture import GMM

#set path
path="path"

class_names = [1,2,3]

covs = ['spherical', 'diag', 'tied', 'full']

training_data = {1: np.loadtxt(path+"/01_train_debug.data"), 2: np.loadtxt(path+"/02_train_debug.data"), 3: np.loadtxt(path+"/03_train_debug.data")}

print "Training models"
models = {}
for c in class_names:
# make a GMM for each of the classes in class_names
models[c] = dict((covar_type,GMM(n_components=4,
covariance_type=covar_type, init_params='wmc',n_init=1, n_iter=20))
for covar_type in covs)


for cov in covs:
for c in class_names:
models[c][cov].fit(training_data[c])

#define test set
test01 = np.loadtxt(path+"/01_test_debug.data")
test02 = np.loadtxt(path+"/02_test_debug.data")
test03 = np.loadtxt(path+"/03_test_debug.data")

testing_data = {1: test01, 2: test02, 3: test03}

probs = {}

print "Calculating Probabilities"

for c in class_names:
probs[c] = {}
for cov in covs:
probs[c][cov] = {}
for p in class_names:
probs[c][cov] = models[p][cov].predict(testing_data[c])


for c in class_names:
print c
for cov in covs:
print " ",cov,
for p in class_names:
print p, probs,
print

我的假设是否正确,或者我的代码中有逻辑错误?sklearn有办法解决这个问题吗?预先感谢您的帮助!

最佳答案

在您的代码中,第一次 models 字典的键是协方差类型,第二次键是类名称。 我误读了您的代码,抱歉.

编辑:如果您想要拟合 GMM 模型下数据的每个样本的可能性,您应该使用 score_samples 方法。 predict 方法不返回概率,而是返回组件分配。

GMM 默认情况下也是非监督模型。如果您想从一堆 GMM 模型中构建一个监督模型,您可能应该将其包装为一个估计器类,该类包装它们并实现拟合/预测 API,以便能够通过交叉验证估计其准确性并调整超参数值通过网格搜索。 Pull request #2468正在实现这样的事情。如果及时合并,它可能会包含在下一个 scikit-learn 版本中(0.15 应该会在 2014 年初发布)。

关于python - sklearn GMM 分类器的分类输出错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/20593553/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com