Skip to content

Commit

Permalink
Add Ruby variation of the ExtraTreesClassifier estimator
Browse files Browse the repository at this point in the history
  • Loading branch information
Darius Morawiec committed Sep 26, 2017
1 parent 3775501 commit 81b9914
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 1 deletion.
2 changes: 1 addition & 1 deletion readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ Transpile trained [scikit-learn](https:/scikit-learn/scikit-learn) e
<td align="center"><a href="examples/classifier/ExtraTreesClassifier/js/basics.ipynb">✓</a></td>
<td align="center"></td>
<td align="center">✓</td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td><a href="http://scikit-learn.org/0.18/modules/generated/sklearn.ensemble.AdaBoostClassifier.html">sklearn.ensemble.AdaBoostClassifier</a></td>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class {class_name}

{method}

end

if ARGV.length == {n_features}
atts = ARGV.collect {{ |i| i.to_f }}
puts {class_name}.{method_name}(atts)
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{methods}
def self.{method_name} (atts)
classes = Array.new({n_classes}, 0)
{method_calls}
pos_max = classes.each_with_index.select {{|e, i| e==classes.max}}.map &:last
return pos_max.min
end
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
idx = {class_name}.{method_name}(atts); classes[idx] = classes[idx] + 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
def self.{method_name}_{method_id} (atts)
classes = Array.new({n_classes}, 0)
{tree_branches}
pos_max = classes.each_with_index.select {{|e, i| e==classes.max}}.map &:last
return pos_max.min
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# -*- coding: utf-8 -*-

import unittest

from sklearn.ensemble import ExtraTreesClassifier

from ..Classifier import Classifier
from ...language.Ruby import Ruby


class ExtraTreesClassifierRubyTest(Ruby, Classifier, unittest.TestCase):

def setUp(self):
super(ExtraTreesClassifierRubyTest, self).setUp()
self.mdl = ExtraTreesClassifier(random_state=0)

def tearDown(self):
super(ExtraTreesClassifierRubyTest, self).tearDown()

# @unittest.skip('The generated code would be too large.')
# def test_existing_features_w_digits_data(self):
# pass
#
# @unittest.skip('The generated code would be too large.')
# def test_random_features_w_digits_data(self):
# pass

0 comments on commit 81b9914

Please sign in to comment.