Skip to content

Commit

Permalink
Add Ruby variation of the decision tree classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
Darius Morawiec committed Sep 25, 2017
1 parent a9e8a96 commit a404c4f
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 3 deletions.
2 changes: 1 addition & 1 deletion readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ Transpile trained [scikit-learn](https:/scikit-learn/scikit-learn) e
<td align="center"><a href="examples/classifier/DecisionTreeClassifier/js/basics.ipynb">✓</a></td>
<td align="center"></td>
<td align="center"><a href="examples/classifier/DecisionTreeClassifier/php/basics.ipynb">✓</a></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td><a href="http://scikit-learn.org/0.18/modules/generated/sklearn.ensemble.RandomForestClassifier.html">sklearn.ensemble.RandomForestClassifier</a></td>
Expand Down
16 changes: 14 additions & 2 deletions sklearn_porter/classifier/DecisionTreeClassifier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,18 @@ class DecisionTreeClassifier(Classifier):
'arr': '$classes[{0}] = {1}',
'indent': ' ',
'join': '; ',
},
'ruby': {
# 'init': '{name} = {value}',
# 'type': '{0}',
'if': 'if atts[{0}] {1} {2}',
'else': 'else',
'endif': 'end',
'arr': 'classes[{0}] = {1}',
# 'arr[]': '{name} = [{values}]',
# 'arr[][]': '{name} = [{values}]',
'indent': ' ',
'join': ' ',
}
}
# @formatter:on
Expand Down Expand Up @@ -179,7 +191,7 @@ def create_tree(self):
if self.n_features > 1 or (self.n_features == 1 and i >= 0):
feature_indices.append([str(j) for j in range(n_features)][i])

indentation = 1 if self.target_language in ['java', 'js', 'php'] else 0
indentation = 1 if self.target_language in ['java', 'js', 'php', 'ruby'] else 0
return self.create_branches(
self.model.tree_.children_left,
self.model.tree_.children_right,
Expand All @@ -196,7 +208,7 @@ def create_method(self, class_name, method_name):
:return out : string
The built method as string.
"""
n_indents = 1 if self.target_language in ['java', 'js', 'php'] else 0
n_indents = 1 if self.target_language in ['java', 'js', 'php', 'ruby'] else 0
branches = self.indent(self.create_tree(), n_indents=1)
temp_method = self.temp('method', n_indents=n_indents, skipping=True)
out = temp_method.format(class_name=class_name, method_name=method_name,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
class {class_name}
{method}
end

if ARGV.length == {n_features}
atts = ARGV.collect {{ |i| i.to_f }}
puts {class_name}.{method_name}(atts)
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
def self.{method_name} (atts)
classes = Array.new({n_classes}, 0)
{branches}

pos = classes.each_with_index.select {{|e, i| e==classes.max}}.map &:last
return pos.min
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-

from unittest import TestCase

from sklearn.tree import DecisionTreeClassifier

from ..Classifier import Classifier
from ...language.Ruby import Ruby


class DecisionTreeClassifierRubyTest(Ruby, Classifier, TestCase):

def setUp(self):
super(DecisionTreeClassifierRubyTest, self).setUp()
self.mdl = DecisionTreeClassifier(random_state=0)

def tearDown(self):
super(DecisionTreeClassifierRubyTest, self).tearDown()

0 comments on commit a404c4f

Please sign in to comment.