Skip to content

Commit

Permalink
Add '--export' attribute and file handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Darius Morawiec committed Dec 3, 2017
1 parent 5baa0f2 commit 0669645
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 23 deletions.
2 changes: 1 addition & 1 deletion readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ First of all have a quick view on the available arguments:
$ python -m sklearn_porter [-h] --input <PICKLE_FILE> [--output <DEST_DIR>] \
[--class_name <CLASS_NAME>] [--method_name <METHOD_NAME>] \
[--c] [--java] [--js] [--go] [--php] [--ruby] \
[--pipe]
[--export] [--pipe]
```

The following example shows how you can save an trained estimator to the [pickle format](http://scikit-learn.org/stable/modules/model_persistence.html#persistence-example):
Expand Down
10 changes: 5 additions & 5 deletions sklearn_porter/Porter.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,10 @@ def export(self, class_name=None, method_name=None,
with further information.
"""

if class_name is None:
if class_name is None or class_name == '':
class_name = self.estimator_name

if method_name is None:
if method_name is None or method_name == '':
method_name = self.target_method

if isinstance(num_format, types.LambdaType):
Expand All @@ -196,7 +196,7 @@ def export(self, class_name=None, method_name=None,
class_name,
language)
output = {
'model': str(output),
'estimator': str(output),
'filename': filename,
'class_name': class_name,
'method_name': method_name,
Expand Down Expand Up @@ -495,12 +495,12 @@ def _get_filename(class_name, language):
filename : str
The generated filename.
"""
name = str(class_name).lower()
name = str(class_name).strip()
lang = str(language)

# Name:
if language in ['java', 'php']:
name = name.capitalize()
name = "".join([name[0].upper() + name[1:]])

# Suffix:
suffix = {
Expand Down
37 changes: 21 additions & 16 deletions sklearn_porter/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,20 @@ def parse_args(args):
'stored.'))
optional.add_argument(
'--class_name',
default='Brain',
default=None,
required=False,
help='Define the class name in the final output.')
optional.add_argument(
'--method_name',
default='predict',
required=False,
help='Define the method name in the final output.')
optional.add_argument(
'--export', '-e',
required=False,
default=False,
action='store_true',
help='Whether to export the model data or not.')
optional.add_argument(
'--pipe', '-p',
required=False,
Expand Down Expand Up @@ -91,37 +97,36 @@ def main():
language = key
break

# Define destination path:
dest_dir = str(args.get('output'))
if dest_dir == '' or not os.path.isdir(dest_dir):
dest_dir = input_path.split(os.sep)
del dest_dir[-1]
dest_dir = os.sep.join(dest_dir)

# Port estimator:
try:
porter = Porter(estimator, language=language)
class_name = str(args.get('class_name'))
method_name = str(args.get('method_name'))
class_name = args.get('class_name')
method_name = args.get('method_name')
output = porter.export(class_name=class_name,
method_name=method_name,
output=str(args.get('output')),
export_dir=dest_dir,
export_data=bool(args.get('export')),
details=True)
except Exception as e:
sys.exit('Error: {}'.format(str(e)))
else:
# Print transpiled estimator to the console:
if bool(args.get('pipe', False)):
print(output.get('model'))
print(output.get('estimator'))
sys.exit(0)

# Define destination path:
dest_dir = str(args.get('output'))
filename = output.get('filename')
if dest_dir != '' and os.path.isdir(dest_dir):
dest_path = os.path.join(dest_dir, filename)
else:
dest_dir = input_path.split(os.sep)
del dest_dir[-1]
dest_dir += [filename]
dest_path = os.sep.join(dest_dir)

dest_path = dest_dir + os.sep + filename
# Save transpiled estimator:
with open(dest_path, 'w') as file_:
file_.write(output.get('model'))
file_.write(output.get('estimator'))


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion tests/PorterTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_python_command_execution(self):
joblib.dump(self.estimator, pkl_path)

# Port estimator:
cmd = 'python -m sklearn_porter -i {}'.format(pkl_path).split()
cmd = 'python -m sklearn_porter -i {} --class_name Brain'.format(pkl_path).split()
subp.call(cmd)
# Compare file contents:
equal = filecmp.cmp(cp_src, cp_dest)
Expand Down

0 comments on commit 0669645

Please sign in to comment.