gpt4 book ai didi

python - 在 Numpy 数组的列上应用条件/过滤器

转载 作者:行者123 更新时间:2023-11-30 09:34:06 26 4
gpt4 key购买 nike

我有 2 个 Numpy 数组,第一个有 210 行,第二个有 30 行,都包含 4 列,我想对两个数组的第四列应用条件/过滤器,其中仅包含 0 或 1。所以,我想检测第一个数组的 0 作为 Train_Safe,第一个数组的 1 作为 Train_Cracked,第二个数组的 0 作为 Test_Safe,第二个数组的 1 作为 Test_Cracked,并使用 Matplotlib 将这些值绘制在 3D 散点图上,我尝试使用此代码:

    for i in X_train_merge[0:, 3]:
if i == 0:
x_vals_train_0 = X_train_merge[0:, 0:1]
y_vals_train_0 = X_train_merge[0:, 1:2]
z_vals_train_0 = X_train_merge[0:, 2:3]
elif i == 1:
x_vals_train_1 = X_train_merge[0:, 0:1]
y_vals_train_1 = X_train_merge[0:, 1:2]
z_vals_train_1 = X_train_merge[0:, 2:3]
for j in X_test_merge[0:, 3]:
if j == 0:
x_vals_test_0 = X_test_merge[0:, 0:1]
y_vals_test_0 = X_test_merge[0:, 1:2]
z_vals_test_0 = X_test_merge[0:, 2:3]
elif j == 1:
x_vals_test_1 = X_test_merge[0:, 0:1]
y_vals_test_1 = X_test_merge[0:, 1:2]
z_vals_test_1 = X_test_merge[0:, 2:3]

ax.scatter(x_vals_train_0, y_vals_train_0, z_vals_train_0, c='g', marker='o', label="Train_Safe")
ax.scatter(x_vals_train_1, y_vals_train_1, z_vals_train_1, c='b', marker='o', label="Train_Cracked")
ax.scatter(x_vals_test_0, y_vals_test_0, z_vals_test_0, c='black', marker='*', label="Test_Safe")
ax.scatter(x_vals_test_1, y_vals_test_1, z_vals_test_1, c='brown', marker='*', label="Test_Cracked")

它绘制/给出所有数据点,而不将其分解/划分为 Train_Safe、Train_Cracked、Test_Safe 和 Test_Cracked。对于此任务的任何建议/解决方案。提前致谢。

最佳答案

提供玩具数据是有礼貌的

import numpy as np
a = np.random.rand(10, 4)

a[:, 3] = a[:, 3] > 0.5

a

np.array([[ 0.93011873, 0.80167023, 0.46502502, 0. ],
[ 0.48754049, 0.331763 , 0.19391945, 1. ],
[ 0.17976529, 0.73625689, 0.6550934 , 0. ],
[ 0.17797159, 0.89597292, 0.67507392, 1. ],
[ 0.89972382, 0.86131195, 0.85239512, 1. ],
[ 0.59199271, 0.14223656, 0.12101887, 1. ],
[ 0.71962168, 0.89132196, 0.61149278, 0. ],
[ 0.63606024, 0.04821054, 0.49971309, 1. ],
[ 0.18976505, 0.49880633, 0.93362872, 1. ],
[ 0.00154421, 0.79748799, 0.46080879, 0. ]])

那么np.where就是工具:

ts = a[np.where(a[:, -1] == 0), :-1].T

tc = a[np.where(a[:, -1] == 1), :-1].T

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

ax.scatter(*ts, c='g', marker='o', label="Train_Safe")
ax.scatter(*tc, c='b', marker='o', label="Train_Cracked")
fig.show()

enter image description here

关于python - 在 Numpy 数组的列上应用条件/过滤器,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47966665/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com