[GH-5720] AIC/Loglikelihood test fixes (#5723)
* fix metrics generation (#5718)

(cherry picked from commit 3af804b)

* [GH-5720] Add AIC/Loglikelihood metrics to MetricsTestSuite (#5721)

* fix metrics generation (#5718)

(cherry picked from commit 3af804b)

* [GH-5720] Add AIC/Loglikelihood metrics to MetricsTestSuite

(cherry picked from commit b3bef8b)

* [GH-5720] fix H2oGridSearchTestSuite NaN != NaN

* [GH-5720] fix H2oGridSearchTestSuite NaN != NaN

* aic/loglikelihood optional to be backwards compatible

* fix metrics in test_gridsearch.py

* fix test_mojo.py nan == nan
krasinski committed Apr 8, 2024
1 parent b3bef8b commit edbd884
Showing 4 changed files with 12 additions and 7 deletions.
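
The "NaN != NaN" fixes in the commit message above all trace back to the same IEEE 754 rule: NaN never compares equal to anything, including itself, so strict equality assertions fail as soon as a metric such as AIC or Loglikelihood comes back as NaN for a given algorithm. Below is a minimal Scala sketch of the failure mode and of the ignore-NaN-entries workaround the fixed tests adopt; the object and maps are illustrative only, not project code.

// Illustrative only, not project code: IEEE 754 NaN never equals itself,
// which is what breaks strict metric-map equality in the test suites.
object NaNComparisonSketch extends App {
  println(Double.NaN == Double.NaN) // false -- the root cause of the failures

  // Workaround taken by the fixed tests: drop NaN-valued entries on both sides
  // before comparing the remaining metric values.
  val expected = Map("RMSE" -> 0.42, "AIC" -> Double.NaN)
  val actual = Map("RMSE" -> 0.42, "AIC" -> Double.NaN)
  def dropNaN(metrics: Map[String, Double]): Map[String, Double] =
    metrics.filter(!_._2.isNaN)
  println(dropNaN(expected) == dropNaN(actual)) // true -- NaN entries are ignored
}
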
@@ -21,5 +21,6 @@ object MetricFieldExceptions {
   def ignored(): Set[String] =
     Set("__meta", "domain", "model", "model_checksum", "frame", "frame_checksum", "model_category", "predictions")
 
-  def optional(): Set[String] = Set("custom_metric_name", "custom_metric_value", "mean_score", "mean_normalized_score")
+  def optional(): Set[String] =
+    Set("custom_metric_name", "custom_metric_value", "mean_score", "mean_normalized_score", "AIC", "loglikelihood")
 }
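
For context on the "backwards compatible" bullet in the commit message: fields listed in optional() may be absent from a backend's metric output without failing the field comparison, so older H2O versions that do not report AIC/loglikelihood still pass. The snippet below is a hypothetical illustration of how such a set can be consumed; the real MetricsTestSuite logic is not part of this diff.

// Hypothetical consumer of MetricFieldExceptions (not the real MetricsTestSuite code).
object OptionalFieldCheckSketch extends App {
  val ignored = Set("__meta", "domain", "model", "model_checksum", "frame",
    "frame_checksum", "model_category", "predictions")
  val optional = Set("custom_metric_name", "custom_metric_value", "mean_score",
    "mean_normalized_score", "AIC", "loglikelihood")

  // Fields that must be present: everything expected, minus ignored and optional ones.
  def unexpectedlyMissing(expected: Set[String], actual: Set[String]): Set[String] =
    expected -- actual -- ignored -- optional

  // A backend without AIC/loglikelihood no longer trips the check:
  println(unexpectedlyMissing(Set("MSE", "RMSE", "AIC"), Set("MSE", "RMSE"))) // Set()
}
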
@@ -23,10 +23,9 @@ import ai.h2o.sparkling.{SharedH2OTestContext, TestUtils}
 import hex.Model
 import hex.tree.gbm.GBMModel.GBMParameters
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
 import org.apache.spark.ml.param.{ParamMap, Params}
 import org.apache.spark.ml.{Pipeline, PipelineModel}
-import org.apache.spark.sql.SparkSession
 import org.junit.runner.RunWith
 import org.scalatest.junit.JUnitRunner
 import org.scalatest.{FunSuite, Matchers}
@@ -268,6 +267,8 @@ class H2OGridSearchTestSuite extends FunSuite with Matchers with SharedH2OTestCo
   }
 
   test("The first row returned by getGridModelsMetrics() method is the same as current metrics of the best model") {
+    import spark.implicits._
+
     val drf = new H2ODRF()
       .setFeaturesCols(Array("AGE", "RACE", "DPROS", "DCAPS", "PSA", "VOL", "GLEASON"))
       .setLabelCol("CAPSULE")
@@ -288,7 +289,8 @@ class H2OGridSearchTestSuite extends FunSuite with Matchers with SharedH2OTestCo
     val gridModelMetrics = search.getGridModelsMetrics().drop("MOJO Model ID")
     val modelMetricsFromGrid = gridModelMetrics.columns.zip(gridModelMetrics.head().toSeq).toMap
 
-    modelMetricsFromGrid shouldEqual expectedMetrics
+    modelMetricsFromGrid.filter(!_._2.asInstanceOf[Double].isNaN) should contain theSameElementsAs expectedMetrics
+      .filter(!_._2.isNaN)
   }
 
   test("The first row returned by getGridModelsParams() method is the same as training params of the best model") {
2 changes: 1 addition & 1 deletion py/tests/unit/with_runtime_sparkling/test_gridsearch.py
@@ -109,7 +109,7 @@ def testGetGridModelsMetrics(prostateDataset):
     grid.fit(prostateDataset)
     metrics = grid.getGridModelsMetrics()
     assert metrics.count() == 3
-    expectedCols = ['MOJO Model ID', 'RMSLE', 'Nobs', 'RMSE', 'MAE', 'MeanResidualDeviance', 'ScoringTime', 'MSE', 'R2']
+    expectedCols = ['MOJO Model ID', 'RMSLE', 'Nobs', 'RMSE', 'MAE', 'MeanResidualDeviance', 'ScoringTime', "Loglikelihood", 'MSE', 'R2', 'AIC']
     assert metrics.columns == expectedCols
     metrics.collect() # try materializing

6 changes: 4 additions & 2 deletions py/tests/unit/with_runtime_sparkling/test_mojo.py
@@ -19,6 +19,7 @@
 import shutil
 import unit_test_utils
 import os
+import math
 
 from pyspark.mllib.linalg import *
 from pyspark.sql.types import *
@@ -63,7 +64,7 @@ def testModelCategory(gbmModel):
 def testTrainingMetrics(gbmModel):
     metrics = gbmModel.getTrainingMetrics()
     assert metrics is not None
-    assert len(metrics) is 10
+    assert len(metrics) is 12
 
 
 def testFeatureTypes(gbmModel):
@@ -245,7 +246,8 @@ def compareMetricValues(metricsObject, metricsMap):
     for metric in metricsMap:
         metricValue = metricsMap[metric]
         objectValue = getattr(metricsObject, "get" + metric)()
-        assert(metricValue == objectValue)
+        if not math.isnan(metricValue) and not math.isnan(objectValue):
+            assert(metricValue == objectValue)
     assert metricsObject.getConfusionMatrix().count() > 0
     assert len(metricsObject.getConfusionMatrix().columns) > 0
     assert metricsObject.getGainsLiftTable().count() > 0
