diff --git a/readme.md b/readme.md
index d0659038..33509fe1 100644
--- a/readme.md
+++ b/readme.md
@@ -61,7 +61,7 @@ Transpile trained [scikit-learn](https://github.com/scikit-learn/scikit-learn) e
✓ |
|
✓ |
- |
+ ✓ |
sklearn.ensemble.RandomForestClassifier |
diff --git a/sklearn_porter/classifier/DecisionTreeClassifier/__init__.py b/sklearn_porter/classifier/DecisionTreeClassifier/__init__.py
index b368bc9f..37736721 100644
--- a/sklearn_porter/classifier/DecisionTreeClassifier/__init__.py
+++ b/sklearn_porter/classifier/DecisionTreeClassifier/__init__.py
@@ -47,6 +47,18 @@ class DecisionTreeClassifier(Classifier):
'arr': '$classes[{0}] = {1}',
'indent': ' ',
'join': '; ',
+ },
+ 'ruby': {
+ # 'init': '{name} = {value}',
+ # 'type': '{0}',
+ 'if': 'if atts[{0}] {1} {2}',
+ 'else': 'else',
+ 'endif': 'end',
+ 'arr': 'classes[{0}] = {1}',
+ # 'arr[]': '{name} = [{values}]',
+ # 'arr[][]': '{name} = [{values}]',
+ 'indent': ' ',
+ 'join': ' ',
}
}
# @formatter:on
@@ -179,7 +191,7 @@ def create_tree(self):
if self.n_features > 1 or (self.n_features == 1 and i >= 0):
feature_indices.append([str(j) for j in range(n_features)][i])
- indentation = 1 if self.target_language in ['java', 'js', 'php'] else 0
+ indentation = 1 if self.target_language in ['java', 'js', 'php', 'ruby'] else 0
return self.create_branches(
self.model.tree_.children_left,
self.model.tree_.children_right,
@@ -196,7 +208,7 @@ def create_method(self, class_name, method_name):
:return out : string
The built method as string.
"""
- n_indents = 1 if self.target_language in ['java', 'js', 'php'] else 0
+ n_indents = 1 if self.target_language in ['java', 'js', 'php', 'ruby'] else 0
branches = self.indent(self.create_tree(), n_indents=1)
temp_method = self.temp('method', n_indents=n_indents, skipping=True)
out = temp_method.format(class_name=class_name, method_name=method_name,
diff --git a/sklearn_porter/classifier/DecisionTreeClassifier/templates/ruby/class.txt b/sklearn_porter/classifier/DecisionTreeClassifier/templates/ruby/class.txt
new file mode 100644
index 00000000..62a9fc64
--- /dev/null
+++ b/sklearn_porter/classifier/DecisionTreeClassifier/templates/ruby/class.txt
@@ -0,0 +1,8 @@
+class {class_name}
+ {method}
+end
+
+if ARGV.length == {n_features}
+ atts = ARGV.collect {{ |i| i.to_f }}
+ puts {class_name}.{method_name}(atts)
+end
\ No newline at end of file
diff --git a/sklearn_porter/classifier/DecisionTreeClassifier/templates/ruby/method.txt b/sklearn_porter/classifier/DecisionTreeClassifier/templates/ruby/method.txt
new file mode 100644
index 00000000..d751203b
--- /dev/null
+++ b/sklearn_porter/classifier/DecisionTreeClassifier/templates/ruby/method.txt
@@ -0,0 +1,7 @@
+def self.{method_name} (atts)
+ classes = Array.new({n_classes}, 0)
+ {branches}
+
+ pos = classes.each_with_index.select {{|e, i| e==classes.max}}.map &:last
+ return pos.min
+end
\ No newline at end of file
diff --git a/tests/classifier/DecisionTreeClassifier/DecisionTreeClassifierRubyTest.py b/tests/classifier/DecisionTreeClassifier/DecisionTreeClassifierRubyTest.py
new file mode 100644
index 00000000..cfcbf89d
--- /dev/null
+++ b/tests/classifier/DecisionTreeClassifier/DecisionTreeClassifierRubyTest.py
@@ -0,0 +1,18 @@
+# -*- coding: utf-8 -*-
+
+from unittest import TestCase
+
+from sklearn.tree import DecisionTreeClassifier
+
+from ..Classifier import Classifier
+from ...language.Ruby import Ruby
+
+
+class DecisionTreeClassifierRubyTest(Ruby, Classifier, TestCase):
+
+ def setUp(self):
+ super(DecisionTreeClassifierRubyTest, self).setUp()
+ self.mdl = DecisionTreeClassifier(random_state=0)
+
+ def tearDown(self):
+ super(DecisionTreeClassifierRubyTest, self).tearDown()