-
Notifications
You must be signed in to change notification settings - Fork 86
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
79 changed files
with
5,459 additions
and
87 deletions.
There are no files selected for viewing
54 changes: 54 additions & 0 deletions
54
function/python/brightics/function/classification/test/ada_boost_classification_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
""" | ||
Copyright 2019 Samsung SDS | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
|
||
from brightics.function.classification.ada_boost_classification import ada_boost_classification_train | ||
from brightics.function.classification.ada_boost_classification import ada_boost_classification_predict | ||
from brightics.common.datasets import load_iris | ||
import unittest | ||
import pandas as pd | ||
import numpy as np | ||
|
||
|
||
class ADABoostClassification(unittest.TestCase): | ||
|
||
def setUp(self): | ||
print("*** ADA Boost Classification UnitTest Start ***") | ||
self.testdata = load_iris() | ||
|
||
def tearDown(self): | ||
print("*** ADA Boost Classification UnitTest End ***") | ||
|
||
def test(self): | ||
ada_train = ada_boost_classification_train(self.testdata, feature_cols=['sepal_length', 'sepal_width', 'petal_length', 'petal_width'], label_col='species', random_state=12345) | ||
ada_model = ada_train['model']['classifier'] | ||
estimators = ada_model.estimators_ if hasattr(ada_model, 'estimators_') else None | ||
classes = ada_model.classes_ if hasattr(ada_model, 'classes_') else None | ||
n_classes = ada_model.n_classes_ if hasattr(ada_model, 'n_classes_') else None | ||
estimator_weights = ada_model.estimator_weights_ if hasattr(ada_model, 'estimator_weights_') else None | ||
estimator_errors = ada_model.estimator_errors_ if hasattr(ada_model, 'estimator_errors_') else None | ||
feature_importances = ada_model.feature_importances_ if hasattr(ada_model, 'feature_importances_') else None | ||
|
||
np.testing.assert_array_equal([round(x, 15) for x in estimator_weights], [1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000,1.000000000000000]) | ||
np.testing.assert_array_equal([round(x, 15) for x in estimator_errors], [0.333333333333333,0.060002258916742,0.333355372258786,0.120397413941043,0.333354186061486,0.000246496771993,0.333379977227275,0.000324696290229,0.333365707421006,0.000377301321863,0.333379703755078,0.000349084998612,0.333368068990284,0.000377542088862,0.333377933238813,0.000352106375872,0.333369432842104,0.000372712089944,0.333376615861411,0.000354056001133,0.333370427040033,0.000369265051501,0.333375673592510,0.000355694330200,0.333371161245900,0.000366896338034,0.333374996123643,0.000357017448219,0.333371702726590,0.000365249477557,0.333374506534519,0.000358050974252,0.333372101321344,0.000364091056983,0.333374151478965,0.000358841484973,0.333372394265748,0.000363268581665,0.333373893375635,0.000359437901881,0.333372609287275,0.000362680377430,0.333373705442854,0.000359883750645,0.333372766955686,0.000362257368360,0.333373568446543,0.000360214927717,0.333372882480536,0.000361951872769]) | ||
np.testing.assert_array_equal([round(x, 15) for x in feature_importances], [0.000000000000000,0.000000000000000,0.440000000000000,0.560000000000000]) | ||
|
||
predict = ada_boost_classification_predict(self.testdata, ada_train['model']) | ||
prob1 = predict['out_table']['probability_1'] | ||
prob2 = predict['out_table']['probability_2'] | ||
np.testing.assert_array_equal([round(x, 15) for x in prob1], [0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.000006788589536,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.449747463798886,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.024737382237678,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.535991707392255,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.976767518408969,0.021963346892192,0.021963346892192,0.021963346892192,0.021963346892192,0.021963346892192,0.021963346892192,0.480033074405363,0.021963346892192,0.021963346892192,0.021963346892192,0.021963346892192,0.021963346892192,0.021963346892192,0.021963346892192,0.021963346892192,0.021963346892192,0.021963346892192,0.021963346892192,0.021963346892192,0.535991707392255,0.021963346892192,0.449747463798886,0.021963346892192,0.449747463798886,0.021963346892192,0.021963346892192,0.449747463798886,0.449747463798886,0.021963346892192,0.535991707392255,0.021963346892192,0.021963346892192,0.021963346892192,0.535991707392255,0.535991707392255,0.021963346892192,0.021963346892192,0.021963346892192,0.449747463798886,0.021963346892192,0.021963346892192,0.021963346892192,0.021963346892192,0.021963346892192,0.021963346892192,0.021963346892192,0.021963346892192,0.021963346892192,0.021963346892192,0.021963346892192]) | ||
np.testing.assert_array_equal([round(x, 15) for x in prob2], [0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.000000161467062,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.550252536200809,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.975262617762322,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.464008292607232,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.023232481545849,0.978036653107808,0.978036653107808,0.978036653107808,0.978036653107808,0.978036653107808,0.978036653107808,0.519966925594152,0.978036653107808,0.978036653107808,0.978036653107808,0.978036653107808,0.978036653107808,0.978036653107808,0.978036653107808,0.978036653107808,0.978036653107808,0.978036653107808,0.978036653107808,0.978036653107808,0.464008292607232,0.978036653107808,0.550252536200809,0.978036653107808,0.550252536200809,0.978036653107808,0.978036653107808,0.550252536200809,0.550252536200809,0.978036653107808,0.464008292607232,0.978036653107808,0.978036653107808,0.978036653107808,0.464008292607232,0.464008292607232,0.978036653107808,0.978036653107808,0.978036653107808,0.550252536200809,0.978036653107808,0.978036653107808,0.978036653107808,0.978036653107808,0.978036653107808,0.978036653107808,0.978036653107808,0.978036653107808,0.978036653107808,0.978036653107808,0.978036653107808]) | ||
|
65 changes: 65 additions & 0 deletions
65
function/python/brightics/function/classification/test/decision_tree_classification_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
""" | ||
Copyright 2019 Samsung SDS | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
import unittest | ||
from brightics.common.datasets import load_iris | ||
from brightics.function.classification.decision_tree_classification import decision_tree_classification_train, \ | ||
decision_tree_classification_predict | ||
|
||
|
||
class DecisionTreeClassificationTest(unittest.TestCase): | ||
|
||
def setUp(self): | ||
print("*** Decision Tree Classification UnitTest Start ***") | ||
self.iris = load_iris() | ||
|
||
def tearDown(self): | ||
print("*** Decision Tree Classification UnitTest End ***") | ||
|
||
def test_decision_tree_classification1(self): | ||
train_out = decision_tree_classification_train(table=self.iris, feature_cols=['sepal_length', 'sepal_width', 'petal_length', 'petal_width'], label_col='species', random_state=12345) | ||
|
||
table = train_out['model']['feature_importance'] | ||
self.assertAlmostEqual(table[0], 0.02666667, 6) | ||
self.assertAlmostEqual(table[1], 0.0, 1) | ||
self.assertAlmostEqual(table[2], 0.55072262, 6) | ||
self.assertAlmostEqual(table[3], 0.42261071, 6) | ||
|
||
def test_decision_tree_classification2(self): | ||
train_out = decision_tree_classification_train(table=self.iris, feature_cols=['sepal_length', 'sepal_width', 'petal_length', 'petal_width'], label_col='species', random_state=12345) | ||
predict_out = decision_tree_classification_predict(table=self.iris, model=train_out['model']) | ||
|
||
table = predict_out['out_table'].values.tolist() | ||
self.assertListEqual(table[0], [5.1, 3.5, 1.4, 0.2, 'setosa', 'setosa']) | ||
self.assertListEqual(table[1], [4.9, 3.0, 1.4, 0.2, 'setosa', 'setosa']) | ||
self.assertListEqual(table[2], [4.7, 3.2, 1.3, 0.2, 'setosa', 'setosa']) | ||
self.assertListEqual(table[3], [4.6, 3.1, 1.5, 0.2, 'setosa', 'setosa']) | ||
self.assertListEqual(table[4], [5.0, 3.6, 1.4, 0.2, 'setosa', 'setosa']) | ||
|
||
def test_decision_tree_classification3(self): | ||
train_out = decision_tree_classification_train(table=self.iris, feature_cols=['sepal_length', 'sepal_width', 'petal_length', 'petal_width'], label_col='species', random_state=12345, criterion='entropy', max_leaf_nodes=2, group_by=['species']) | ||
predict_out = decision_tree_classification_predict(table=self.iris, model=train_out['model']) | ||
|
||
table = predict_out['out_table'].values.tolist() | ||
self.assertListEqual(table[0], [5.1, 3.5, 1.4, 0.2, 'setosa', 'setosa']) | ||
self.assertListEqual(table[1], [4.9, 3.0, 1.4, 0.2, 'setosa', 'setosa']) | ||
self.assertListEqual(table[2], [4.7, 3.2, 1.3, 0.2, 'setosa', 'setosa']) | ||
self.assertListEqual(table[3], [4.6, 3.1, 1.5, 0.2, 'setosa', 'setosa']) | ||
self.assertListEqual(table[4], [5.0, 3.6, 1.4, 0.2, 'setosa', 'setosa']) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
48 changes: 48 additions & 0 deletions
48
function/python/brightics/function/classification/test/knn_classification_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
""" | ||
Copyright 2019 Samsung SDS | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
import unittest | ||
import numpy as np | ||
from sklearn.model_selection import train_test_split | ||
from brightics.common.datasets import load_iris | ||
from brightics.function.classification import knn_classification | ||
|
||
|
||
class TestKNNClassification(unittest.TestCase): | ||
|
||
def test_default(self): | ||
df_iris = load_iris() | ||
df_train, df_test = train_test_split(df_iris, random_state=12345) | ||
df_res = knn_classification(train_table=df_train, test_table=df_test, | ||
feature_cols=['sepal_length', 'sepal_width'], label_col='species', | ||
k=5, algorithm='auto', leaf_size=30, p=2)['out_table'] | ||
|
||
self.assertListEqual(['versicolor', 'setosa', 'virginica', 'setosa', 'setosa'], df_res['prediction'].tolist()[:5], 'incorrect prediction') | ||
np.testing.assert_array_almost_equal([0.0, 1.0, 0.0, 1.0, 1.0], df_res['probability_0'].values[:5], 7, 'incorrect probability_0') | ||
np.testing.assert_array_almost_equal([0.8, 0.0, 0.2, 0.0, 0.0], df_res['probability_1'].values[:5], 7, 'incorrect probability_1') | ||
np.testing.assert_array_almost_equal([0.2, 0.0, 0.8, 0.0, 0.0], df_res['probability_2'].values[:5], 7, 'incorrect probability_2') | ||
|
||
def test_optional(self): | ||
df_iris = load_iris() | ||
df_train, df_test = train_test_split(df_iris, random_state=12345) | ||
df_res = knn_classification(train_table=df_train, test_table=df_test, | ||
feature_cols=['sepal_length', 'sepal_width', 'petal_length'], label_col='species', | ||
k=10, algorithm='auto', leaf_size=30, p=2)['out_table'] | ||
|
||
self.assertListEqual(['versicolor', 'setosa', 'versicolor', 'setosa', 'setosa'], df_res['prediction'].tolist()[:5], 'incorrect prediction') | ||
np.testing.assert_array_almost_equal([0.0, 1.0, 0.0, 1.0, 1.0], df_res['probability_0'].values[:5], 7, 'incorrect probability_0') | ||
np.testing.assert_array_almost_equal([1.0, 0.0, 0.7, 0.0, 0.0], df_res['probability_1'].values[:5], 7, 'incorrect probability_1') | ||
np.testing.assert_array_almost_equal([0.0, 0.0, 0.3, 0.0, 0.0], df_res['probability_2'].values[:5], 7, 'incorrect probability_2') |
Oops, something went wrong.