gpt4 book ai didi

python - mod.predict 给出的列比预期多

转载 作者:太空宇宙 更新时间:2023-11-03 15:34:28 25 4
gpt4 key购买 nike

我在 IRIS 数据集上使用 MXNet,该数据集有 4 个特征,并将花朵分类为 -“setosa”、“versicolor”、“virginica”。我的训练数据有 89 行。我的标签数据是 89 列的行向量。我将花名编码为数字 -0,1,2,因为 mx.io.NDArrayIter 似乎不接受带有字符串值的 numpy ndarray。然后我尝试使用来预测

re = mod.predict(test_iter)

我得到的结果形状为 14 * 10。为什么我只有 3 个标签时却得到 10 列,以及如何将这些结果映射到我的标签。预测结果如下所示:

[[ 0.11760861 0.12082944 0.1207106 0.09154381 0.09155304 0.09155869 0.09154817 0.09155204 0.09154914 0.09154641] [ 0.1176083 0.12082954 0.12071151 0.09154379 0.09155323 0.09155825 0.0915481 0.09155164 0.09154923 0.09154641] [ 0.11760829 0.1208293 0.12071083 0.09154385 0.09155313 0.09155875 0.09154838 0.09155186 0.09154932 0.09154625] [ 0.11760861 0.12082901 0.12071037 0.09154388 0.09155303 0.09155875 0.09154829 0.09155209 0.09154959 0.09154641] [ 0.11760896 0.12082863 0.12070955 0.09154405 0.09155299 0.09155875 0.09154839 0.09155225 0.09154996 0.09154646] [ 0.1176089 0.1208287 0.1207095 0.09154407 0.09155297 0.09155882 0.09154844 0.09155232 0.09154989 0.0915464 ] [ 0.11760896 0.12082864 0.12070941 0.09154408 0.09155297 0.09155882 0.09154844 0.09155234 0.09154993 0.09154642] [ 0.1176088 0.12082874 0.12070983 0.09154399 0.09155302 0.09155872 0.09154837 0.09155215 0.09154984 0.09154641] [ 0.11760852 0.12082904 0.12071032 0.09154394 0.09155304 0.09155876 0.09154835 0.09155209 0.09154959 0.09154631] [ 0.11760963 0.12082832 0.12070873 0.09154428 0.09155257 0.09155893 0.09154856 0.09155177 0.09155051 0.09154671] [ 0.11760966 0.12082829 0.12070868 0.09154429 0.09155258 0.09155892 0.09154858 0.0915518 0.09155052 0.09154672] [ 0.11760949 0.1208282 0.12070852 0.09154446 0.09155259 0.09155893 0.09154854 0.09155205 0.0915506 0.09154666] [ 0.11760952 0.12082817 0.12070853 0.0915444 0.09155261 0.09155891 0.09154853 0.09155206 0.09155057 0.09154668] [ 0.1176096 0.1208283 0.12070892 0.09154423 0.09155267 0.09155882 0.09154859 0.09155172 0.09155044 0.09154676]]

最佳答案

使用“y = mod.predict(val_iter,num_batch=1)”而不是“y = mod.predict(val_iter)”,那么你只能得到一个批处理标签。例如,如果您的batch_size为10,那么您将只获得10个标签。

关于python - mod.predict 给出的列比预期多,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42641657/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com