Skip to content

Commit

Permalink
Merge pull request #68 from slickml/amir/dev
Browse files Browse the repository at this point in the history
Fixed save_paths in plotting and prepared for new version
  • Loading branch information
amirhessam88 authored May 16, 2021
2 parents c81d723 + 4cf63cc commit 8817eca
Show file tree
Hide file tree
Showing 12 changed files with 209 additions and 81 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,22 @@ All notable changes to this project will be documented in this file.
This project adheres to [Semantic Versioning](http://semver.org/).

## Unreleased


## Version 0.1.3 - 2021-05-15

### Fixed
* [#66](https:/slickml/slick-ml/pull/66) fixed bugs in feature selection algorithm.
* [#67](https:/slickml/slick-ml/pull/67) fixed bugs in metrics.

### Updated
* [#66](https:/slickml/slick-ml/pull/66) updated the order of the functions inside each class.
* [#68](https:/slickml/slick-ml/pull/68) updated `save_path` in plotting functions.
* [#68](https:/slickml/slick-ml/pull/68) updated `bibtex` citations to software.

### Added
* [#68](https:/slickml/slick-ml/pull/68) added directories for `JOSS` and `NeurIPS` papers.


## Version 0.1.2 - 2021-04-17

Expand Down
22 changes: 7 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,24 +178,16 @@ let others know that you are working on it. Whether the contributions consists o
levels. The SlickML community goals are to be helpful and effective.

## Citing SlickML
If you use SlickML in academic work, please consider citing
https://doi.org/10.1117/12.2304418 .
If you use SlickML in academic work, please consider citing it.

### Bibtex Entry:
```bib
@inproceedings{tahmassebi2018ideeple,
title={ideeple: Deep learning in a flash},
author={Tahmassebi, Amirhessam},
booktitle={Disruptive Technologies in Information Sciences},
volume={10652},
pages={106520S},
year={2018},
organization={International Society for Optics and Photonics}
@software{slickml2020,
title={SlickML: Slick Machine Learning in Python},
author={Tahmassebi, Amirhessam and Smith, Trace},
url={https:/slickml/slick-ml},
version={0.1.3},
year={2021},
}
```
### APA Entry:

Tahmassebi, A. (2018, May). ideeple: Deep learning in a flash. In Disruptive
Technologies in Information Sciences (Vol. 10652, p. 106520S). International
Society for Optics and Photonics.

55 changes: 28 additions & 27 deletions examples/metrics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [
{
Expand All @@ -61,7 +61,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -101,7 +101,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -131,7 +131,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -169,47 +169,47 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\" >\n",
" #T_b4381_ th {\n",
" #T_2c43c_ th {\n",
" font-size: 12px;\n",
" text-align: left;\n",
" font-weight: bold;\n",
" } #T_b4381_ td {\n",
" } #T_2c43c_ td {\n",
" font-size: 12px;\n",
" text-align: center;\n",
" }#T_b4381_row0_col0,#T_b4381_row0_col1,#T_b4381_row0_col2,#T_b4381_row0_col3,#T_b4381_row0_col4,#T_b4381_row0_col5,#T_b4381_row0_col6,#T_b4381_row0_col7,#T_b4381_row0_col8,#T_b4381_row0_col9,#T_b4381_row0_col10,#T_b4381_row0_col11,#T_b4381_row0_col12,#T_b4381_row0_col13,#T_b4381_row0_col14{\n",
" }#T_2c43c_row0_col0,#T_2c43c_row0_col1,#T_2c43c_row0_col2,#T_2c43c_row0_col3,#T_2c43c_row0_col4,#T_2c43c_row0_col5,#T_2c43c_row0_col6,#T_2c43c_row0_col7,#T_2c43c_row0_col8,#T_2c43c_row0_col9,#T_2c43c_row0_col10,#T_2c43c_row0_col11,#T_2c43c_row0_col12,#T_2c43c_row0_col13,#T_2c43c_row0_col14{\n",
" background-color: #e5e5ff;\n",
" color: #000000;\n",
" }</style><table id=\"T_b4381_\" ><thead> <tr> <th class=\"blank level0\" ></th> <th class=\"col_heading level0 col0\" >Accuracy</th> <th class=\"col_heading level0 col1\" >Balanced Accuracy</th> <th class=\"col_heading level0 col2\" >ROC AUC</th> <th class=\"col_heading level0 col3\" >PR AUC</th> <th class=\"col_heading level0 col4\" >Precision</th> <th class=\"col_heading level0 col5\" >Recall</th> <th class=\"col_heading level0 col6\" >Average Precision</th> <th class=\"col_heading level0 col7\" >F-1 Score</th> <th class=\"col_heading level0 col8\" >F-2 Score</th> <th class=\"col_heading level0 col9\" >F-0.50 Score</th> <th class=\"col_heading level0 col10\" >Threat Score</th> <th class=\"col_heading level0 col11\" >TP</th> <th class=\"col_heading level0 col12\" >TN</th> <th class=\"col_heading level0 col13\" >FP</th> <th class=\"col_heading level0 col14\" >FN</th> </tr></thead><tbody>\n",
" }</style><table id=\"T_2c43c_\" ><thead> <tr> <th class=\"blank level0\" ></th> <th class=\"col_heading level0 col0\" >Accuracy</th> <th class=\"col_heading level0 col1\" >Balanced Accuracy</th> <th class=\"col_heading level0 col2\" >ROC AUC</th> <th class=\"col_heading level0 col3\" >PR AUC</th> <th class=\"col_heading level0 col4\" >Precision</th> <th class=\"col_heading level0 col5\" >Recall</th> <th class=\"col_heading level0 col6\" >Average Precision</th> <th class=\"col_heading level0 col7\" >F-1 Score</th> <th class=\"col_heading level0 col8\" >F-2 Score</th> <th class=\"col_heading level0 col9\" >F-0.50 Score</th> <th class=\"col_heading level0 col10\" >Threat Score</th> <th class=\"col_heading level0 col11\" >TP</th> <th class=\"col_heading level0 col12\" >TN</th> <th class=\"col_heading level0 col13\" >FP</th> <th class=\"col_heading level0 col14\" >FN</th> </tr></thead><tbody>\n",
" <tr>\n",
" <th id=\"T_b4381_level0_row0\" class=\"row_heading level0 row0\" >Threshold = 0.500 | Average =\n",
" <th id=\"T_2c43c_level0_row0\" class=\"row_heading level0 row0\" >Threshold = 0.500 | Average =\n",
" Binary</th>\n",
" <td id=\"T_b4381_row0_col0\" class=\"data row0 col0\" >0.968000</td>\n",
" <td id=\"T_b4381_row0_col1\" class=\"data row0 col1\" >0.957000</td>\n",
" <td id=\"T_b4381_row0_col2\" class=\"data row0 col2\" >0.988000</td>\n",
" <td id=\"T_b4381_row0_col3\" class=\"data row0 col3\" >0.992000</td>\n",
" <td id=\"T_b4381_row0_col4\" class=\"data row0 col4\" >0.952000</td>\n",
" <td id=\"T_b4381_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
" <td id=\"T_b4381_row0_col6\" class=\"data row0 col6\" >0.991000</td>\n",
" <td id=\"T_b4381_row0_col7\" class=\"data row0 col7\" >0.975000</td>\n",
" <td id=\"T_b4381_row0_col8\" class=\"data row0 col8\" >0.990000</td>\n",
" <td id=\"T_b4381_row0_col9\" class=\"data row0 col9\" >0.961000</td>\n",
" <td id=\"T_b4381_row0_col10\" class=\"data row0 col10\" >0.952000</td>\n",
" <td id=\"T_b4381_row0_col11\" class=\"data row0 col11\" >118</td>\n",
" <td id=\"T_b4381_row0_col12\" class=\"data row0 col12\" >64</td>\n",
" <td id=\"T_b4381_row0_col13\" class=\"data row0 col13\" >6</td>\n",
" <td id=\"T_b4381_row0_col14\" class=\"data row0 col14\" >0</td>\n",
" <td id=\"T_2c43c_row0_col0\" class=\"data row0 col0\" >0.968000</td>\n",
" <td id=\"T_2c43c_row0_col1\" class=\"data row0 col1\" >0.957000</td>\n",
" <td id=\"T_2c43c_row0_col2\" class=\"data row0 col2\" >0.988000</td>\n",
" <td id=\"T_2c43c_row0_col3\" class=\"data row0 col3\" >0.992000</td>\n",
" <td id=\"T_2c43c_row0_col4\" class=\"data row0 col4\" >0.952000</td>\n",
" <td id=\"T_2c43c_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
" <td id=\"T_2c43c_row0_col6\" class=\"data row0 col6\" >0.991000</td>\n",
" <td id=\"T_2c43c_row0_col7\" class=\"data row0 col7\" >0.975000</td>\n",
" <td id=\"T_2c43c_row0_col8\" class=\"data row0 col8\" >0.990000</td>\n",
" <td id=\"T_2c43c_row0_col9\" class=\"data row0 col9\" >0.961000</td>\n",
" <td id=\"T_2c43c_row0_col10\" class=\"data row0 col10\" >0.952000</td>\n",
" <td id=\"T_2c43c_row0_col11\" class=\"data row0 col11\" >118</td>\n",
" <td id=\"T_2c43c_row0_col12\" class=\"data row0 col12\" >64</td>\n",
" <td id=\"T_2c43c_row0_col13\" class=\"data row0 col13\" >6</td>\n",
" <td id=\"T_2c43c_row0_col14\" class=\"data row0 col14\" >0</td>\n",
" </tr>\n",
" </tbody></table>"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x7f0c94808f10>"
"<pandas.io.formats.style.Styler at 0x7f2f4aab8990>"
]
},
"metadata": {},
Expand All @@ -228,7 +228,8 @@
],
"source": [
"example1 = BinaryClassificationMetrics(y_true, y_pred_proba, precision_digits=3)\n",
"example1.plot()"
"example1.plot(figsize=(12, 12),\n",
" save_path=None)"
]
},
{
Expand Down
File renamed without changes
File renamed without changes.
File renamed without changes.
File renamed without changes
2 changes: 1 addition & 1 deletion slickml/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.2"
__version__ = "0.1.3"
60 changes: 42 additions & 18 deletions slickml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ def plot_feature_importance(
markerfacecolor=None,
markeredgewidth=None,
fontsize=None,
save_path=None,
):

"""Function to plot XGBoost feature importance.
Expand Down Expand Up @@ -347,18 +348,23 @@ def plot_feature_importance(
fontsize: int or float, optional, (default=12)
Fontsize for xlabel and ylabel, and ticks parameters
save_path: str, optional (default=None)
The full or relative path to save the plot including the image format.
For example "myplot.png" or "../../myplot.pdf"
"""

plot_xgb_feature_importance(
self.feature_importance_,
figsize,
color,
marker,
markersize,
markeredgecolor,
markerfacecolor,
markeredgewidth,
fontsize,
feature_importance=self.feature_importance_,
figsize=figsize,
color=color,
marker=marker,
markersize=markersize,
markeredgecolor=markeredgecolor,
markerfacecolor=markerfacecolor,
markeredgewidth=markeredgewidth,
fontsize=fontsize,
save_path=save_path,
)

def plot_shap_summary(
Expand All @@ -377,6 +383,7 @@ def plot_shap_summary(
class_names=None,
class_inds=None,
color_bar_label=None,
save_path=None,
):
"""Function to plot shap summary plot.
This function is a helper function to plot the shap summary plot
Expand Down Expand Up @@ -433,6 +440,10 @@ def plot_shap_summary(
color_bar_label: str, optional, (default="Feature Value")
Label for color bar
save_path: str, optional (default=None)
The full or relative path to save the plot including the image format.
For example "myplot.png" or "../../myplot.pdf"
"""

# define tree explainer
Expand Down Expand Up @@ -466,6 +477,7 @@ def plot_shap_summary(
class_names=class_names,
class_inds=class_inds,
color_bar_label=color_bar_label,
save_path=save_path,
)

def plot_shap_waterfall(
Expand All @@ -483,6 +495,7 @@ def plot_shap_waterfall(
max_display=None,
title=None,
fontsize=None,
save_path=None,
):
"""Function to plot shap waterfall plot.
This function is a helper function to plot the shap waterfall plot
Expand Down Expand Up @@ -536,6 +549,10 @@ def plot_shap_waterfall(
fontsize: int or float, optional, (default=12)
Fontsize for xlabel and ylabel, and ticks parameters
save_path: str, optional (default=None)
The full or relative path to save the plot including the image format.
For example "myplot.png" or "../../myplot.pdf"
"""

# define tree explainer
Expand Down Expand Up @@ -568,6 +585,7 @@ def plot_shap_waterfall(
max_display=max_display,
title=title,
fontsize=fontsize,
save_path=save_path,
)

def _dtrain(self, X_train, y_train):
Expand Down Expand Up @@ -1026,6 +1044,7 @@ def plot_cv_results(
train_std_color=None,
test_color=None,
test_std_color=None,
save_path=None,
):
"""
Function to plot the results of xgboost.cv() process and evolution
Expand Down Expand Up @@ -1060,18 +1079,23 @@ def plot_cv_results(
test_std_color: str, optional, (default="#D0AAF3")
Color of the edge color of the testing std bars
save_path: str, optional (default=None)
The full or relative path to save the plot including the image format.
For example "myplot.png" or "../../myplot.pdf"
"""

plot_xgb_cv_results(
self.cv_results_,
figsize,
linestyle,
train_label,
test_label,
train_color,
train_std_color,
test_color,
test_std_color,
cv_results=self.cv_results_,
figsize=figsize,
linestyle=linestyle,
train_label=train_label,
test_label=test_label,
train_color=train_color,
train_std_color=train_std_color,
test_color=test_color,
test_std_color=test_std_color,
save_path=save_path,
)

def _cv(self):
Expand Down
42 changes: 30 additions & 12 deletions slickml/feature_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,7 @@ def plot_frequency(
markerfacecolor=None,
markeredgewidth=None,
fontsize=None,
save_path=None,
):

"""Function to plot selected features frequency.
Expand Down Expand Up @@ -614,20 +615,27 @@ def plot_frequency(
"""

plot_xfs_feature_frequency(
self.feature_frequency_,
figsize,
freq_pct,
color,
marker,
markersize,
markeredgecolor,
markerfacecolor,
markeredgewidth,
fontsize,
freq=self.feature_frequency_,
figsize=figsize,
freq_pct=freq_pct,
color=color,
marker=marker,
markersize=markersize,
markeredgecolor=markeredgecolor,
markerfacecolor=markerfacecolor,
markeredgewidth=markeredgewidth,
fontsize=fontsize,
save_path=save_path,
)

def plot_cv_results(
self, figsize=None, int_color=None, ext_color=None, sharex=False, sharey=False
self,
figsize=None,
int_color=None,
ext_color=None,
sharex=False,
sharey=False,
save_path=None,
):
"""Function to plot the cross-validation results of
XGBoostFeatureSelector. It visualizes the internal
Expand Down Expand Up @@ -655,12 +663,22 @@ def plot_cv_results(
sharey: bool, optional, (default=False)
Flag to share "Y" axis for each row of subplots
save_path: str, optional (default=None)
The full or relative path to save the plot including the image format.
For example "myplot.png" or "../../myplot.pdf"
kwargs: dict
Plotting object plotting_cv_
"""

plot_xfs_cv_results(
figsize, int_color, ext_color, sharex, sharey, **self.plotting_cv_
figsize=figsize,
int_color=int_color,
ext_color=ext_color,
sharex=sharex,
sharey=sharey,
save_path=save_path,
**self.plotting_cv_,
)

def get_xgb_params(self):
Expand Down
Loading

0 comments on commit 8817eca

Please sign in to comment.