diff --git a/readme.md b/readme.md index d0659038..33509fe1 100644 --- a/readme.md +++ b/readme.md @@ -61,7 +61,7 @@ Transpile trained [scikit-learn](https://github.com/scikit-learn/scikit-learn) e - + ✓ sklearn.ensemble.RandomForestClassifier diff --git a/sklearn_porter/classifier/DecisionTreeClassifier/__init__.py b/sklearn_porter/classifier/DecisionTreeClassifier/__init__.py index b368bc9f..37736721 100644 --- a/sklearn_porter/classifier/DecisionTreeClassifier/__init__.py +++ b/sklearn_porter/classifier/DecisionTreeClassifier/__init__.py @@ -47,6 +47,18 @@ class DecisionTreeClassifier(Classifier): 'arr': '$classes[{0}] = {1}', 'indent': ' ', 'join': '; ', + }, + 'ruby': { + # 'init': '{name} = {value}', + # 'type': '{0}', + 'if': 'if atts[{0}] {1} {2}', + 'else': 'else', + 'endif': 'end', + 'arr': 'classes[{0}] = {1}', + # 'arr[]': '{name} = [{values}]', + # 'arr[][]': '{name} = [{values}]', + 'indent': ' ', + 'join': ' ', } } # @formatter:on @@ -179,7 +191,7 @@ def create_tree(self): if self.n_features > 1 or (self.n_features == 1 and i >= 0): feature_indices.append([str(j) for j in range(n_features)][i]) - indentation = 1 if self.target_language in ['java', 'js', 'php'] else 0 + indentation = 1 if self.target_language in ['java', 'js', 'php', 'ruby'] else 0 return self.create_branches( self.model.tree_.children_left, self.model.tree_.children_right, @@ -196,7 +208,7 @@ def create_method(self, class_name, method_name): :return out : string The built method as string. """ - n_indents = 1 if self.target_language in ['java', 'js', 'php'] else 0 + n_indents = 1 if self.target_language in ['java', 'js', 'php', 'ruby'] else 0 branches = self.indent(self.create_tree(), n_indents=1) temp_method = self.temp('method', n_indents=n_indents, skipping=True) out = temp_method.format(class_name=class_name, method_name=method_name, diff --git a/sklearn_porter/classifier/DecisionTreeClassifier/templates/ruby/class.txt b/sklearn_porter/classifier/DecisionTreeClassifier/templates/ruby/class.txt new file mode 100644 index 00000000..62a9fc64 --- /dev/null +++ b/sklearn_porter/classifier/DecisionTreeClassifier/templates/ruby/class.txt @@ -0,0 +1,8 @@ +class {class_name} + {method} +end + +if ARGV.length == {n_features} + atts = ARGV.collect {{ |i| i.to_f }} + puts {class_name}.{method_name}(atts) +end \ No newline at end of file diff --git a/sklearn_porter/classifier/DecisionTreeClassifier/templates/ruby/method.txt b/sklearn_porter/classifier/DecisionTreeClassifier/templates/ruby/method.txt new file mode 100644 index 00000000..d751203b --- /dev/null +++ b/sklearn_porter/classifier/DecisionTreeClassifier/templates/ruby/method.txt @@ -0,0 +1,7 @@ +def self.{method_name} (atts) + classes = Array.new({n_classes}, 0) + {branches} + + pos = classes.each_with_index.select {{|e, i| e==classes.max}}.map &:last + return pos.min +end \ No newline at end of file diff --git a/tests/classifier/DecisionTreeClassifier/DecisionTreeClassifierRubyTest.py b/tests/classifier/DecisionTreeClassifier/DecisionTreeClassifierRubyTest.py new file mode 100644 index 00000000..cfcbf89d --- /dev/null +++ b/tests/classifier/DecisionTreeClassifier/DecisionTreeClassifierRubyTest.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- + +from unittest import TestCase + +from sklearn.tree import DecisionTreeClassifier + +from ..Classifier import Classifier +from ...language.Ruby import Ruby + + +class DecisionTreeClassifierRubyTest(Ruby, Classifier, TestCase): + + def setUp(self): + super(DecisionTreeClassifierRubyTest, self).setUp() + self.mdl = DecisionTreeClassifier(random_state=0) + + def tearDown(self): + super(DecisionTreeClassifierRubyTest, self).tearDown()