
python - Trouble implementing the branching of a decision tree from scratch (long)

Reposted. Author: 行者123. Last updated: 2023-11-30 09:52:14

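I'm currently implementing the decision tree algorithm from scratch in Python. I'm having trouble implementing the branching of the tree. The current implementation does not use a depth parameter.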

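What happens is that either the branching ends too early (if I use a flag to prevent infinite recursion), or, if I remove the flag, I run into infinite recursion. I also have trouble telling whether I'm in the main loop or in a recursive call.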

My data is very simple:

d = {'one' : [1., 2., 3., 4.],
     'two' : [4., 3., 2., 1.]}

df = pd.DataFrame(d)

df['three'] = (0,0,1,1)
df = np.array(df)

This produces the output:

array([[ 1.,  4.,  0.],
       [ 2.,  3.,  0.],
       [ 3.,  2.,  1.],
       [ 4.,  1.,  1.]])

I'll use gini_index for the splits. That function isn't essential to my problem, so I've put it at the end of this question for reproducibility.
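For comparison, here is a minimal sketch of the usual size-weighted Gini impurity, in which each group scores 1 minus the sum of its squared class proportions, so a pure group scores 0. The names `gini_impurity` and `weighted_gini` are hypothetical and are not the question's own functions; the sketch assumes the class label is the last element of each row.

```python
from collections import Counter

def gini_impurity(labels):
    """Gini impurity of one group: 1 - sum(p_k ** 2) over the classes k."""
    n = len(labels)
    if n == 0:
        return 0.0  # an empty group contributes nothing
    counts = Counter(labels)
    return 1.0 - sum((c / n) ** 2 for c in counts.values())

def weighted_gini(groups):
    """Size-weighted average impurity of the groups produced by one split."""
    total = sum(len(g) for g in groups)
    if total == 0:
        return 0.0
    return sum(
        len(g) / total * gini_impurity([row[-1] for row in g])
        for g in groups
    )

rows = [[1., 4., 0.], [2., 3., 0.], [3., 2., 1.], [4., 1., 1.]]
perfect = weighted_gini([rows[:2], rows[2:]])  # both groups are pure
unsplit = weighted_gini([[], rows])            # every row on one side
```

With this definition, lower is better: `perfect` evaluates to 0.0 and `unsplit` to 0.5 for the four sample rows.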

I'm using a dictionary object, y, which keeps accumulating nested dictionaries as the branches grow.

                          y
                       /     \
              y['left']       y['right']
              /       \                \
y['left']['left']  y['left']['right']  y['right']['right']
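To make the nested-dictionary layout concrete, here is a small sketch of how such a tree could be walked to classify a new row. It assumes each internal node stores the 'Index' and 'Value' keys returned by check_split, that 'left' and 'right' hold either another dictionary or a plain array of training rows, and that rows with a value below 'Value' go left. The `predict` helper is hypothetical and not part of the question's code.

```python
import numpy as np

def predict(node, row):
    """Follow left/right links until a leaf (an array of rows) is reached."""
    while isinstance(node, dict):
        branch = 'left' if row[node['Index']] < node['Value'] else 'right'
        node = node[branch]
    leaf = np.asarray(node)
    if leaf.size == 0:
        return None  # empty leaf: no training rows to vote with
    # Majority vote over the class labels stored in the last column.
    labels, counts = np.unique(leaf[:, -1], return_counts=True)
    return labels[np.argmax(counts)]

tree = {
    'Index': 0, 'Value': 3.0,
    'left': np.array([[1., 4., 0.], [2., 3., 0.]]),
    'right': np.array([[3., 2., 1.], [4., 1., 1.]]),
}
low = predict(tree, [1.5, 3.5])   # first column below 3.0, so goes left
high = predict(tree, [3.5, 1.0])  # first column at or above 3.0, so goes right
```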

Next I'll break down the function that creates the tree, which is where I'm running into problems.

def create_tree2(node, flag):
    # node is a dictionary containing the root; it accumulates nested
    # dictionaries as this function calls itself recursively.

    # 'Groups' is a key that holds the two groups used for the next split;
    # they are assigned to left and right here.
    left, right = node['Groups']
    # Convert to arrays because my other functions expect array input.
    left, right = np.array(left), np.array(right)

    print('left_group', left)    # for debugging purposes
    print('right_group', right)

    if flag == True and (right.size == 0 or left.size == 0):
        node['left'] = left
        node['right'] = right
        flag = False
        return

# The portion above is meant to prevent infinite loops.

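As for the infinite recursion: when I have two rows of data, instead of splitting those two rows into two different nodes, I get one node with no rows and another node with both rows.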

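My loop normally stops when a node has fewer than two rows of data. So the empty node terminates, but the node with two rows splits again, into an empty node and a two-row node, and this goes on forever. So I tried using a flag to prevent this infinite loop. The only problem with the flag is that it seems to fire one step too early: it does not check whether the split would lead to two nodes or to an infinite loop. For example: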

A split leads to
left = []
right = [[ 3., 2., 1.],
         [ 4., 1., 1.]]

Now, instead of checking whether right can split further
(left = [3., 2., 1.], right = [4., 1., 1.]),

the flag stops at the step above, one step too early.
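One way to think about the situation described above: when a split sends every row to one side, the non-empty child holds exactly the same rows as its parent, so splitting it again with the same procedure would give the same result. The sketch below tests the outcome of the split itself instead of carrying a flag through the recursion. It is only an illustration under that assumption; `split_rows` mirrors the question's split_method rule, and `is_degenerate` is a hypothetical helper.

```python
def split_rows(rows, index, value):
    """Same rule as the question's split_method: rows below `value` go left."""
    left = [r for r in rows if r[index] < value]
    right = [r for r in rows if r[index] >= value]
    return left, right

def is_degenerate(left, right):
    """True when a split moved every row to one side, so recursing on the
    non-empty side cannot make progress."""
    return len(left) == 0 or len(right) == 0

rows = [[3., 2., 1.], [4., 1., 1.]]
stuck = is_degenerate(*split_rows(rows, 0, 3.0))    # left side is empty
useful = is_degenerate(*split_rows(rows, 0, 4.0))   # one row on each side
```

Here `stuck` is True and `useful` is False, so a caller could turn the node into a leaf in the first case and recurse only in the second.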

    if len(left) < 2:
        node['left'] = left
        return

# Here I end the node if it has fewer than 2 rows of data.

    else:
        node['left'] = check_split(left)
        print('after left split', node['left']['Groups'])  # for debugging purposes
        create_tree2(node['left'], True)

# This splits the data and then recursively calls create_tree2,
# given that the length of the group is NOT less than two.
# The flag gets activated to prevent infinite looping.
# Notice that node['left'] is passed as the node parameter in the recursive call.

    if len(right) < 2:
        node['right'] = right
        return
    else:
        node['right'] = check_split(right)
        print('right_check_split')
        create_tree2(node['right'], False)

# Doing the same thing with the right side.

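The only problem here (or so I assume) is that if the left side recursively calls itself first, the node parameter changes to the node['left'] dictionary, and the local variables left and right are updated with the left branch's information.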

Let's look at the output. Here is what the run prints:

# First split

left_group [[ 1.  4.  0.]
 [ 2.  3.  0.]]

right_group [[ 3.  2.  1.]
 [ 4.  1.  1.]]

# First the left group calls itself recursively, producing an additional split
# that results in a new left group that is empty and a right group with two rows.

left_group []

right_group [[ 1.  4.  0.]
 [ 2.  3.  0.]]

# Now the flag `if` statement fires:

    if flag == True and (right.size == 0 or left.size == 0):
        node['left'] = left
        node['right'] = right
        flag = False
        return

# Ideally I want to do one more split on the right group, to see whether it
# would split further, but I didn't know how to implement that properly.
# I'm assuming I would need some sort of counter?

# Next it jumps to the right main branch correctly. I'm not sure how, since
# `right` was updated after the left's recursive call.

right_check_split

left_group []
right_group [[ 3.  2.  1.]
 [ 4.  1.  1.]]

This also activates the flag, which stops the iteration. Ideally I would like this to go at least one more round, to check whether the right group, [3, 2, 1] and [4, 1, 1], would split into two branches. I'm not sure how to do that.

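Another thing that confuses me is why the dictionary is able to continue from the main right node rather than from the nested dictionary on the left.

Recall that the recursion happens first in the main left branch:

    create_tree2(node['left'], True)

which should update the values of left and right, and those values would then carry over when we reach this part of the function:

    if len(right) < 2:
        node['right'] = right
        return
    else:
        node['right'] = check_split(right)  # Would this right value have been updated?
        print('right_check_split')
        create_tree2(node['right'], False)

So I expected the right value to have been changed to

[[ 1.  4.  0.]
 [ 2.  3.  0.]]

but instead the function remembers the root node's original right value, which is

right_group [[ 3.  2.  1.]
 [ 4.  1.  1.]]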
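The behavior described here follows from how Python handles function calls: every call gets its own set of local variables, so the left and right bound inside a recursive call are separate names from the left and right of the caller. The toy sketch below illustrates this with a hypothetical `visit` function and a made-up 'child' key; it is not the question's tree code.

```python
def visit(node, seen):
    # These names are local to this particular call.
    left, right = node['Groups']
    if 'child' in node:
        # The inner call binds its own left/right; ours are untouched.
        visit(node['child'], seen)
    # After the inner call returns, `right` is still this call's value.
    seen.append(right)

root = {
    'Groups': ('root-left', 'root-right'),
    'child': {'Groups': ('inner-left', 'inner-right')},
}
seen = []
visit(root, seen)
```

After the call, `seen` holds the inner call's right value followed by the outer call's right value, which shows that the outer value was never overwritten.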

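So my questions are:

1) How do I implement the flag correctly, so that it verifies that infinite recursion would really occur before the flag's `if` branch fires?

2) Even though the recursive call updates the arguments with the left branch's values, how is my function able to use the earlier right value (which is what I want) and correctly create the new nested dictionary in the proper place?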

Here is the full code, if needed:

import numpy as np
import pandas as pd


d = {'one' : [1., 2., 3., 4.],
     'two' : [4., 3., 2., 1.]}

df = pd.DataFrame(d)

df['three'] = (0,0,1,1)
df = np.array(df)


def split_method(data, index, value):
    left, right = list(), list()
    for row in data:
        #for i in range((data.shape[-1] -1)):
        if row[index] < value:
            left.append(row)
        else:
            right.append(row)

    return left, right


def gini(data, groups):
    data_size = len(data)

    gini_index = 0

    for group in groups:
        group_size = len(group)
        multiplier = float(group_size/data_size)
        prob = 1
        if group_size == 0:
            continue
        print('multiplier', multiplier)
        for value in set(data[:,-1]):
            prob *= [row[-1] for row in group].count(value)/group_size
            print('prob', prob)
        gini_index += (multiplier * prob)

    return gini_index


def check_split(data):
    main_score = 999
    gini_index = 999
    gini_value = 999
    print('data', data)
    for index in range(len(data[0])-1):
        for rows in data:
            value = rows[index]
            groups = split_method(data, index, value)
            gini_score = gini(data, groups)

            if gini_score < main_score:
                main_score = gini_score
                gini_index, gini_value, gini_groups = index, value, np.array(groups)

    return {'Index': gini_index, 'Value': gini_value, 'Groups': gini_groups}


def create_tree2(node, flag):

    left, right = node['Groups']
    left, right = np.array(left), np.array(right)
    print('left_group', left)
    print('right_group', right)

    if flag == True and (right.size == 0 or left.size == 0):
        node['left'] = left
        node['right'] = right
        flag = False
        return

    if len(left) < 2:
        node['left'] = left
        return

    else:
        node['left'] = check_split(left)
        print('after left split', node['left']['Groups'])
        create_tree2(node['left'], flag=True)

    if len(right) < 2:
        node['right'] = right
        return
    else:
        node['right'] = check_split(right)
        print('right_check_split')
        create_tree2(node['right'], flag=True)

    return node


root = check_split(df)  # this creates the root dictionary (the first dictionary)
y = create_tree2(root, False)

Best answer

I made these changes to your function:

def create_tree2(node, flag=False):

    left, right = node['Groups']
    left, right = np.array(left), np.array(right)
    print('left_group', left)
    print('right_group', right)

    if flag == True and (right.size == 0 or left.size == 0):
        node['left'] = left
        node['right'] = right
        flag = False
        return

    if len(left) < 2:
        node['left'] = left
        flag = True
        print('too-small left. flag=True')
    else:
        node['left'] = check_split(left)
        print('after left split', node['left']['Groups'])
        create_tree2(node['left'], flag)

    if len(right) < 2:
        node['right'] = right
        print('too-small right. flag=True')
        flag = True
    else:
        node['right'] = check_split(right)
        print('after right split', node['right']['Groups'])
        create_tree2(node['right'], flag)

    return node


d = {'one' : [1., 2., 3., 4.],
     'two' : [4., 3., 2., 1.]}

df = pd.DataFrame(d)

df['three'] = (0,0,1,1)
df = np.array(df)

root = check_split(df)
y = create_tree2(root)

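Basically, I used the len < 2 checks to set the flag to True and then allowed the right side to recurse. I still don't think this is right, because something could go wrong when len == 1. But there is no infinite recursion.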

I got this output:

left_group [[ 1.  4.  0.]
[ 2. 3. 0.]]
right_group [[ 3. 2. 1.]
[ 4. 1. 1.]]
after left split [array([], shape=(0, 3), dtype=float64)
array([[ 1., 4., 0.],
[ 2., 3., 0.]])]
left_group []
right_group [[ 1. 4. 0.]
[ 2. 3. 0.]]
too-small left. flag=True
after right split [array([], shape=(0, 3), dtype=float64)
array([[ 1., 4., 0.],
[ 2., 3., 0.]])]
left_group []
right_group [[ 1. 4. 0.]
[ 2. 3. 0.]]
after right split [array([], shape=(0, 3), dtype=float64)
array([[ 3., 2., 1.],
[ 4., 1., 1.]])]
left_group []
right_group [[ 3. 2. 1.]
[ 4. 1. 1.]]
too-small left. flag=True
after right split [array([], shape=(0, 3), dtype=float64)
array([[ 3., 2., 1.],
[ 4., 1., 1.]])]
left_group []
right_group [[ 3. 2. 1.]
[ 4. 1. 1.]]
Y= {'Groups': array([[[ 1., 4., 0.],
[ 2., 3., 0.]],

[[ 3., 2., 1.],
[ 4., 1., 1.]]]), 'Index': 0, 'right': {'Groups': array([array([], shape=(0, 3), dtype=float64),
array([[ 3., 2., 1.],
[ 4., 1., 1.]])], dtype=object), 'Index': 0, 'right': {'Groups': array([array([], shape=(0, 3), dtype=float64),
array([[ 3., 2., 1.],
[ 4., 1., 1.]])], dtype=object), 'Index': 0, 'right': array([[ 3., 2., 1.],
[ 4., 1., 1.]]), 'Value': 3.0, 'left': array([], shape=(0, 3), dtype=float64)}, 'Value': 3.0, 'left': array([], shape=(0, 3), dtype=float64)}, 'Value': 3.0, 'left': {'Groups': array([array([], shape=(0, 3), dtype=float64),
array([[ 1., 4., 0.],
[ 2., 3., 0.]])], dtype=object), 'Index': 0, 'right': {'Groups': array([array([], shape=(0, 3), dtype=float64),
array([[ 1., 4., 0.],
[ 2., 3., 0.]])], dtype=object), 'Index': 0, 'right': array([[ 1., 4., 0.],
[ 2., 3., 0.]]), 'Value': 1.0, 'left': array([], shape=(0, 3), dtype=float64)}, 'Value': 1.0, 'left': array([], shape=(0, 3), dtype=float64)}}

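Also, I think you could improve this by checking at the end whether a node's left or right side is empty and pulling the opposite child up one level. Something like this: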

if node['left'] is empty:
    kid = node['right']
    node.clear()
    for k, v in kid.items():
        node[k] = v
elif node['right'] is empty:
    # same basic thing, with the left kid

Checking for emptiness is the tricky part, because sometimes the child is a dictionary and sometimes it isn't.
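A runnable sketch of that pull-up idea might look like the following. It assumes every dictionary node has both a 'left' and a 'right' key, and that a leaf is an array of rows; `is_empty` and `collapse` are hypothetical names. Unlike the node.clear() approach above, this version returns the replacement node instead of mutating in place, which avoids special-casing a surviving child that is an array and not a dictionary.

```python
import numpy as np

def is_empty(child):
    """A child is either a dict (a subtree) or an array of rows (a leaf)."""
    if isinstance(child, dict):
        return False
    return np.asarray(child).size == 0

def collapse(node):
    """Replace any node that has one empty child with its other child."""
    if not isinstance(node, dict):
        return node  # leaves are returned unchanged
    node['left'] = collapse(node['left'])
    node['right'] = collapse(node['right'])
    if is_empty(node['left']):
        return node['right']
    if is_empty(node['right']):
        return node['left']
    return node

empty = np.empty((0, 3))
rows = np.array([[1., 4., 0.], [2., 3., 0.]])
tree = {
    'Index': 0, 'Value': 3.0,
    'left': {'Index': 0, 'Value': 1.0, 'left': empty, 'right': rows},
    'right': np.array([[3., 2., 1.], [4., 1., 1.]]),
}
pruned = collapse(tree)
```

In this example the inner left node, whose own left child is empty, is replaced by its two-row right child, so `pruned['left']` ends up being the `rows` array.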

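Finally, you don't seem to be storing the actual split information. Isn't that the whole point of a decision tree: knowing which factors to compare? Shouldn't each node record its column and value?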
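Putting these points together, here is one possible sketch of a builder that needs no flag. It is not the code from the question or the answer. It assumes the standard Gini impurity, stops when a node is pure, never considers a split that leaves one side empty, and records the column and threshold in every internal node. The names `best_split` and `build` and the 'leaf'/'size' keys are hypothetical.

```python
from collections import Counter

def gini(labels):
    """Standard Gini impurity: 1 - sum of squared class proportions."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def best_split(rows):
    """Try every (column, value) pair; skip splits that leave one side empty."""
    best = None
    for index in range(len(rows[0]) - 1):
        for value in sorted({r[index] for r in rows}):
            left = [r for r in rows if r[index] < value]
            right = [r for r in rows if r[index] >= value]
            if not left or not right:
                continue  # degenerate split: recursing would make no progress
            score = (len(left) * gini([r[-1] for r in left])
                     + len(right) * gini([r[-1] for r in right])) / len(rows)
            if best is None or score < best[0]:
                best = (score, index, value, left, right)
    return best

def build(rows):
    labels = [r[-1] for r in rows]
    if len(set(labels)) <= 1:
        # Pure (or empty) node: nothing left to separate.
        return {'leaf': labels[0] if labels else None, 'size': len(rows)}
    split = best_split(rows)
    if split is None:
        # Mixed labels but identical features: fall back to the majority class.
        majority = Counter(labels).most_common(1)[0][0]
        return {'leaf': majority, 'size': len(rows)}
    _, index, value, left, right = split
    # Both sides are non-empty and strictly smaller, so the recursion terminates.
    return {'Index': index, 'Value': value,
            'left': build(left), 'right': build(right)}

data = [[1., 4., 0.], [2., 3., 0.], [3., 2., 1.], [4., 1., 1.]]
tree = build(data)
```

For the four sample rows this yields a single split on column 0 at 3.0 with two pure leaves, because each child already contains only one class.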

A similar question to "python - Trouble implementing the branching of a decision tree from scratch (long)" can be found on Stack Overflow: https://stackoverflow.com/questions/43190688/
