Tree recursion in Python
I have the following tree algorithm, which prints the condition for each leaf:
def _grow_tree(self, X, y, depth=0):
    # Identify best split
    idx, thr = self._best_split(X, y)
    # Indentation for tree description
    indent = " " * depth
    indices_left = X.iloc[:, idx] < thr
    X_left = X[indices_left]
    y_left = y_train[X_left.reset_index().loc[:,'id'].values]
    X_right = X[~indices_left]
    y_right = y_train[X_right.reset_index().loc[:,'id'].values]
    self.tree_describe.append(indent +"if x['"+ X.columns[idx] + "'] <= " +\
                              str(thr) + ':')
    # Grow on left side of the tree
    node.left = self._grow_tree(X_left, y_left, depth + 1)
    self.tree_describe.append(indent +"else: #if x['"+ X.columns[idx] + "'] > " +\
                              str(thr) + ':')
    # Grow on right side of the tree
    node.right = self._grow_tree(X_right, y_right, depth + 1)
    return node
For a particular case, this produces the following output:
["if x['VAR1'] <= 0.5:",
" if x['VAR2'] <= 0.5:",
" else: #if x['VAR2'] > 0.5:",
"else: #if x['VAR1'] > 0.5:",
" if x['VAR3'] <= 0.5:",
" else: #if x['VAR3'] > 0.5:"]
How can I get the following output instead?
["if x['VAR1'] <= 0.5:",
" if x['VAR1'] <= 0.5&x['VAR2'] <= 0.5",
" else: #if x['VAR1'] <= 0.5&x['VAR2'] > 0.5:",
"else: #if x['VAR1'] > 0.5:",
" if x['VAR1'] > 0.5&x['VAR3'] <= 0.5:",
" else: #if x['VAR1'] > 0.5&x['VAR3'] > 0.5:"]
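You can add a new parameter to the function that holds a string with the higher-level conditions, so that they can be added to every deeper condition. I would also suggest using .format to build the strings: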
def _grow_tree(self, X, y, depth=0, descr=""):
    idx, thr = self._best_split(X, y)
    indent = " " * depth
    cond = "x['{}'] <= {}{}".format(X.columns[idx], thr, descr)
    self.tree_describe.append("{}if {}:".format(indent, cond))
    node.left = self._grow_tree(X_left, y_left, depth + 1, " & " + cond)
    cond = "x['{}'] > {}{}".format(X.columns[idx], thr, descr)
    self.tree_describe.append("{}else: #if {}:".format(indent, cond))
    node.right = self._grow_tree(X_right, y_right, depth + 1, " & " + cond)
    return node
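Since the snippet above leaves out the node creation and the split itself, here is a minimal self-contained sketch of the same idea that can be run as-is. The toy tree, the describe function and its out parameter are made up for illustration and are not part of the original class. Note that this variant puts the parent conditions in front of the new test, which is the order shown in the question; the snippet above appends them after it.

```python
# Hypothetical toy tree: (column, threshold, left_subtree, right_subtree).
# None marks a leaf. This stands in for the real _best_split logic.
TREE = ("VAR1", 0.5,
        ("VAR2", 0.5, None, None),
        ("VAR3", 0.5, None, None))


def describe(tree, depth=0, descr="", out=None):
    """Collect one line per branch, each carrying all ancestor conditions."""
    if out is None:
        out = []
    if tree is None:  # leaf: nothing left to split
        return out
    column, thr, left, right = tree
    indent = " " * depth

    # Left branch: ancestor conditions first, then this node's own test.
    cond = "{}x['{}'] <= {}".format(descr, column, thr)
    out.append("{}if {}:".format(indent, cond))
    describe(left, depth + 1, cond + " & ", out)

    # Right branch: same ancestor prefix, opposite test.
    cond = "{}x['{}'] > {}".format(descr, column, thr)
    out.append("{}else: #if {}:".format(indent, cond))
    describe(right, depth + 1, cond + " & ", out)
    return out


lines = describe(TREE)
for line in lines:
    print(line)
```

descr is passed down exactly like depth: each call hands its children a longer string, and nothing has to be undone on the way back up because strings are immutable.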
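Isn't the indentation there precisely because you did not want this output? Right now the indentation shows which item is a child of which, and now you want to repeat the x['VAR1'] <= 0.5 part to show that instead.

Originally yes, but now I plan to use the pandas query function to create a leaf column based on the conditions, and the first approach is not practical for that. I need every leaf to carry all of its conditions.

So pass it forward, just as you do with depth. depth grows by +1 on each call, and the description argument grows with each condition.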
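For the use case mentioned in the comments, only the leaves need a full condition. A rough sketch of that, under assumptions: the toy tree and the leaf_conditions helper are hypothetical, and the strings use bare column names because that is the form DataFrame.query expects. Each test is wrapped in parentheses so the result does not depend on how & is prioritised against the comparisons.

```python
# Hypothetical toy tree: (column, threshold, left_subtree, right_subtree).
# None marks a leaf.
TREE = ("VAR1", 0.5,
        ("VAR2", 0.5, None, None),
        ("VAR3", 0.5, None, None))


def leaf_conditions(tree, path=()):
    """Return one condition string per leaf, joining every test on the way down."""
    if tree is None:
        # Parenthesise each test so '&' cannot bind tighter than a comparison.
        return [" & ".join("({})".format(test) for test in path)]
    column, thr, left, right = tree
    left_test = "{} <= {}".format(column, thr)
    right_test = "{} > {}".format(column, thr)
    # The path tuple grows by one test per level, just like depth grows by 1.
    return (leaf_conditions(left, path + (left_test,))
            + leaf_conditions(right, path + (right_test,)))


conditions = leaf_conditions(TREE)
for leaf_id, condition in enumerate(conditions):
    print(leaf_id, condition)
```

With a DataFrame df that has these columns, something like df.query(condition).index should then give the rows that fall into a leaf, which could be used to fill a leaf column.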