Support all parameters in 'DecisionTree' class
This also lifts the restriction of 'max_depth=1' in the 'Tree' class.

Signed-off-by: Niklas Koep <niklas.koep@gmail.com>
nkoep committed Aug 11, 2020
1 parent 77e2c3b commit 2c7b468
Showing 1 changed file with 21 additions and 14 deletions.
tree.py
@@ -2,6 +2,8 @@
 import numpy as np
 from sklearn.base import BaseEstimator, RegressorMixin
 
+from random_state import ensure_random_state
+
 
 njit_cached = numba.njit(cache=True)
 
@@ -77,8 +79,7 @@ class Tree:
     ----------
     max_depth : int or None
         The maximum allowed tree depth. In general, this requires pruning the
-        tree to select the best subtree configuration. For simplicity, we only
-        allow `max_depth=1`.
+        tree to select the best subtree configuration.
     min_samples_split : int
         The minimum number of samples required to split an internal node.
     max_features : int or None
@@ -109,8 +110,8 @@ def __init__(self, max_depth=None, min_samples_split=2, max_features=None,
         self._max_features = max_features
         self._random_state = random_state
 
-        if self._max_depth is not None:
-            assert self._max_depth == 1, "Only 'max_depth=1' allowed"
+        if self._max_depth is None:
+            self._max_depth = np.inf
         if self._max_features is not None:
             assert self._random_state is not None, "No random state provided"
 
@@ -120,7 +121,7 @@ def __init__(self, max_depth=None, min_samples_split=2, max_features=None,
         self.threshold = None
         self.prediction = None
 
-    def construct_tree(self, X, y):
+    def construct_tree(self, X, y, depth=0):
         """Construct the binary decision tree via recursive splitting.
 
         Parameters
@@ -133,12 +134,14 @@ def construct_tree(self, X, y):
         num_samples, num_features = X.shape
 
         # Too few samples to split, so turn the node into a leaf.
-        if num_samples < self._min_samples_split or self._max_depth == 1:
+        if num_samples < self._min_samples_split or depth >= self._max_depth:
             self.prediction = y.mean()
             return
 
+        random_state = ensure_random_state(self._random_state)
+
         if self._max_features is not None:
-            feature_indices = self._random_state.integers(
+            feature_indices = random_state.integers(
                 num_features, size=min(self._max_features, num_features))
         else:
             feature_indices = np.arange(num_features)
@@ -168,16 +171,18 @@ def construct_tree(self, X, y):
         mask_left, mask_right = split["partition"]
 
         self.left = Tree(
+            max_depth=self._max_depth,
             min_samples_split=self._min_samples_split,
             max_features=self._max_features,
-            random_state=self._random_state)
-        self.left.construct_tree(X[mask_left, :], y[mask_left])
+            random_state=random_state)
+        self.left.construct_tree(X[mask_left, :], y[mask_left], depth+1)
 
         self.right = Tree(
+            max_depth=self._max_depth,
             min_samples_split=self._min_samples_split,
             max_features=self._max_features,
-            random_state=self._random_state)
-        self.right.construct_tree(X[mask_right, :], y[mask_right])
+            random_state=random_state)
+        self.right.construct_tree(X[mask_right, :], y[mask_right], depth+1)
 
     def apply_to_sample(self, x):
         """Perform regression on a single observation.
@@ -222,8 +227,7 @@ class DecisionTree(BaseEstimator, RegressorMixin):
     ----------
     max_depth : int or None
         The maximum allowed tree depth. In general, this requires pruning the
-        tree to select the best subtree configuration. For simplicity, we only
-        allow `max_depth=1`.
+        tree to select the best subtree configuration.
     min_samples_split : int
         The minimum number of samples required to split an internal node.
     max_features : int or None
@@ -257,7 +261,10 @@ def fit(self, X, y):
         -------
         self : DecisionTree
         """
-        self.tree_ = Tree(min_samples_split=self.min_samples_split)
+        self.tree_ = Tree(max_depth=self.max_depth,
+                          min_samples_split=self.min_samples_split,
+                          max_features=self.max_features,
+                          random_state=self.random_state)
         self.tree_.construct_tree(*map(np.array, (X, y)))
         return self
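
The `ensure_random_state` helper imported at the top of the diff comes from a sibling `random_state` module that is not shown here. Judging from its call sites, it presumably turns a seed, an existing generator, or None into a `numpy.random.Generator`. A minimal sketch of such a helper (an assumption about that module's contents, not the repository's actual code):

    import numpy as np

    def ensure_random_state(random_state):
        # Assumed behavior: pass an existing Generator through unchanged,
        # so child nodes reuse (and advance) the parent's generator.
        if isinstance(random_state, np.random.Generator):
            return random_state
        # Otherwise treat the argument as a seed (or None) and build a
        # fresh Generator from it.
        return np.random.default_rng(random_state)

Assuming something like this sketch, passing the materialized `random_state` (rather than the raw `self._random_state`) to the two child `Tree(...)` constructors means sibling subtrees draw from one shared generator instead of re-seeding identically.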
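
With the `max_depth=1` restriction lifted, a full tree can now be grown end to end. A quick usage sketch on synthetic data (the data and hyperparameter values are illustrative, and `DecisionTree`'s constructor is assumed to accept the same four parameters that `fit` forwards to `Tree`):

    import numpy as np
    from tree import DecisionTree

    rng = np.random.default_rng(0)
    X = rng.standard_normal((200, 5))
    y = X[:, 0] + 0.1 * rng.standard_normal(200)

    # All four parameters now reach the recursive Tree construction.
    model = DecisionTree(max_depth=3, min_samples_split=4,
                         max_features=3, random_state=rng)
    model.fit(X, y)
    print(model.tree_.apply_to_sample(X[0]))  # regression value for one sample

Since `construct_tree` starts at `depth=0` and turns a node into a leaf once `depth >= self._max_depth`, `max_depth=3` here allows at most three levels of splits.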
