Skip to content

Commit

Permalink
Add PHP variation of the RandomForestClassifier estimator
Browse files Browse the repository at this point in the history
  • Loading branch information
Darius Morawiec committed Sep 26, 2017
1 parent db63686 commit faac38d
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 3 deletions.
2 changes: 1 addition & 1 deletion readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ Transpile trained [scikit-learn](https:/scikit-learn/scikit-learn) e
<td align="center"><a href="examples/classifier/RandomForestClassifier/java/basics.ipynb">✓</a></td>
<td align="center"><a href="examples/classifier/RandomForestClassifier/js/basics.ipynb">✓</a></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
Expand Down
12 changes: 10 additions & 2 deletions sklearn_porter/classifier/RandomForestClassifier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,15 @@ class RandomForestClassifier(Classifier):
'arr': 'classes[{0}] = {1}',
'indent': ' ',
'join': '; ',
}
},
'php': {
'if': 'if ($atts[{0}] {1} {2}) {{',
'else': '} else {',
'endif': '}',
'arr': '$classes[{0}] = {1}',
'indent': ' ',
'join': '; ',
},
}
# @formatter:on

Expand Down Expand Up @@ -232,7 +240,7 @@ def create_method(self):
fns = '\n'.join(fns)

# Merge generated content:
n_indents = 1 if self.target_language in ['java', 'js'] else 0
n_indents = 1 if self.target_language in ['java', 'js', 'php'] else 0
temp_method = self.temp('method')
out = temp_method.format(method_name=self.method_name,
method_calls=fn_names, methods=fns,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
<?php

class {class_name} {{

{method}

}}

if ($argc > 1) {{
array_shift($argv);
$prediction = {class_name}::{method_name}($argv);
fwrite(STDOUT, $prediction);
exit(0);
}}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{methods}
public static function {method_name}($atts) {{
$n_classes = {n_classes};
$classes = array_fill(0, $n_classes, 0);
{method_calls}
$class_idx = 0;
$class_val = $classes[0];
for ($i = 0; $i < $n_classes; $i++) {{
if ($classes[$i] > $class_val) {{
$class_idx = $i;
$class_val = $classes[$i];
}}
}}
return $class_idx;
}}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
$classes[{class_name}::{method_name}($atts)]++;
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
public static function {method_name}_{method_id}($atts) {{
$n_classes = {n_classes};
$classes = array_fill(0, $n_classes, 0);
{tree_branches}
$class_idx = 0;
$class_val = $classes[0];
for ($i = 0; $i < $n_classes; $i++) {{
if ($classes[$i] > $class_val) {{
$class_idx = $i;
$class_val = $classes[$i];
}}
}}
return $class_idx;
}}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-

import unittest
from unittest import TestCase

from sklearn.ensemble import RandomForestClassifier

from ..Classifier import Classifier
from ...language.PHP import PHP


class RandomForestClassifierPHPTest(PHP, Classifier, TestCase):

def setUp(self):
super(RandomForestClassifierPHPTest, self).setUp()
self.mdl = RandomForestClassifier(n_estimators=100, random_state=0)

def tearDown(self):
super(RandomForestClassifierPHPTest, self).tearDown()

@unittest.skip('The generated code would be too large.')
def test_existing_features_w_digits_data(self):
pass

@unittest.skip('The generated code would be too large.')
def test_random_features_w_digits_data(self):
pass

0 comments on commit faac38d

Please sign in to comment.