gpt4 book ai didi

python - 高效的概率树分支

转载 作者:行者123 更新时间:2023-12-04 03:34:45 26 4
gpt4 key购买 nike

例如,我有一个分支和概率数组,如下所示:

paths = np.array([
[1, 0, 1.0],
[2, 0, 0.4],
[2, 1, 0.6],
[3, 1, 1.0],
[5, 1, 0.25],
[5, 2, 0.5],
[5, 4, 0.25],
[6, 0, 0.7],
[6, 5, 0.2],
[6, 2, 0.1]])
列是上节点、下节点、概率。
这是节点的视觉效果:
        6
/ | \
5 0 2
/ | \ / \
1 2 4 0 1
| /\ |
0 0 1 0
|
0
我希望能够选择一个起始节点并输出一个分支数组和累积概率,包括所有重复的分支。例如: start_node = 5应该回来
array([
[5, 1, 0.25],
[5, 2, 0.5],
[5, 4, 0.25],
[1, 0, 0.25],
[2, 0, 0.2],
[2, 1, 0.3],
[1, 0, 0.3]])
注意 [1, 0, x]分支被包含两次,因为它是由 [5, 1, 0.25] 提供的分支和 [2, 1, 0.3]分支。
这是我正在使用的一些代码,但对于我的应用程序(数百万个分支)来说太慢了:
def branch(start_node, paths):
output = paths[paths[:,0]==start_node]
next_nodes = output

while True:
can_go_lower = np.isin(next_nodes[:,1], paths[:,0])

if ~np.any(can_go_lower): break

next_nodes_checked = next_nodes[can_go_lower]

next_nodes = np.empty([0,3])
for nodes in next_nodes_checked:
to_append = paths[paths[:,0]==nodes[1]]
to_append[:,2] *= nodes[2]
next_nodes = np.append(next_nodes, to_append, axis=0)

output = np.append(output, next_nodes, axis=0)

return output
Twig 总是从高到低,因此陷入圈子不是问题。一种矢量化 for 的方法循环并避免 append s 将是最好的优化,我认为。

最佳答案

让我们将图形存储在 dict 中,而不是存储在 numpy 数组中.

tree = {k:arr[arr[:, 0] == k] for k in np.unique(arr[:, 0])}
制作为一组非叶节点:
non_leaf_nodes = set(np.unique(arr[:, 0]))
现在找到分支和累积概率:
def branch(start_node, tree, non_leaf_nodes):
curr_nodes = [[start_node, start_node, 1.0]] #(prev_node, starting_node, current_probability)
output = []
while True:
next_nodes = []
for _, node, prob in curr_nodes:
if node not in non_leaf_nodes: continue
subtree = tree[node]
to_append = subtree.copy()
to_append[:, 2] *= prob
to_append = to_append.tolist()
output += to_append
next_nodes += to_append
curr_nodes = next_nodes
if len(curr_nodes) == 0:
break
return np.array(output)
输出:
>>> branch(5, tree, non_leaf_nodes)

array([
[5. , 1. , 0.25],
[5. , 2. , 0.5 ],
[5. , 4. , 0.25],
[1. , 0. , 0.25],
[2. , 0. , 0.2 ],
[2. , 1. , 0.3 ],
[1. , 0. , 0.3 ]])
我期待它工作得更快。让我知道。

关于python - 高效的概率树分支,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67131111/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com