gpt4 book ai didi

python - Numpy apply_along_axis 推断出错误的数据类型

转载 作者:行者123 更新时间:2023-11-28 18:19:30 25 4
gpt4 key购买 nike

我在使用 NumPy 时遇到以下问题:

代码:

import numpy as np
get_label = lambda x: 'SMALL' if x.sum() <= 10 else 'BIG'
arr = np.array([[1, 2], [30, 40]])
print np.apply_along_axis(get_label, 1, arr)
arr = np.array([[30, 40], [1, 2]])
print np.apply_along_axis(get_label, 1, arr)

输出:

['SMALL' 'BIG']
['BIG' 'SMA'] # String 'SMALL' is stripped!

我可以看到 NumPy 以某种方式从函数返回的第一个值推断出数据类型。我想出了以下解决方法 - 从具有明确声明的 dtype 而不是字符串的函数返回 NumPy 数组,并 reshape 结果:

def get_label_2(x):
if x.sum() <= 10:
return np.array(['SMALL'], dtype='|S5')
else:
return np.array(['BIG'], dtype='|S5')
arr = np.array([[30, 40], [1, 2]])
print np.apply_along_axis(get_label_2, 1, arr).reshape(arr.shape[0])

你知道这个问题更优雅的解决方案吗?

最佳答案

你可以使用np.where:

arr1 = np.array([[1, 2], [30, 40]])
arr2 = np.array([[30, 40], [1, 2]])

print(np.where(arr1.sum(axis=1)<=10,'SMALL','BIG'))
print(np.where(arr2.sum(axis=1)<=10,'SMALL','BIG'))
['SMALL' 'BIG']
['BIG' 'SMALL']

在函数中:

def get_label(x, threshold, axis=1, label1='SMALL', label2='BIG'):
return np.where(x.sum(axis=axis) <= threshold, label1, label2)

关于python - Numpy apply_along_axis 推断出错误的数据类型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46096748/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com