Skip to content

Commit

Permalink
Added the numFeatures parameter to the IsolationForestModel class (incl. saved model metadata). (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
jverbus committed Jun 7, 2022
1 parent 9ab34bd commit ec66605
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 9 deletions.
Expand Up @@ -128,7 +128,7 @@ class IsolationForest(override val uid: String) extends Estimator[IsolationFores
}).collect()

val isolationForestModel = copyValues(
new IsolationForestModel(uid, isolationTrees, numSamples).setParent(this))
new IsolationForestModel(uid, isolationTrees, numSamples, numFeatures).setParent(this))

// Determine and set the model threshold based upon the specified contamination and
// contaminationError parameters.
Expand Down
Expand Up @@ -16,16 +16,24 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
*
* @param uid The immutable unique ID for the model.
* @param isolationTrees The array of isolation tree models that compose the isolation forest.
* @param numSamples The number of samples used to train each tree.
* @param numFeatures The user-specified number of features used to train each isolation tree. For certain edge cases,
* a given isolation tree may not have any nodes using some of these features, e.g., a shallow tree
* where the number of features in the training data exceeds the number of nodes in the tree.
*/
class IsolationForestModel(
override val uid: String,
val isolationTrees: Array[IsolationTree],
private val numSamples: Int)
private val numSamples: Int,
private val numFeatures: Int)
extends Model[IsolationForestModel] with IsolationForestParams with MLWritable {

require(numSamples > 0, s"parameter numSamples must be >0, but given invalid value ${numSamples}")
final def getNumSamples: Int = numSamples

require(numFeatures > 0, s"parameter numFeatures must be >0, but given invalid value ${numFeatures}")
final def getNumFeatures: Int = numFeatures

// The outlierScoreThreshold needs to be a mutable variable because it is not known when an
// IsolationForestModel instance is created.
private var outlierScoreThreshold: Double = -1
Expand All @@ -40,7 +48,7 @@ class IsolationForestModel(

override def copy(extra: ParamMap): IsolationForestModel = {

val isolationForestCopy = new IsolationForestModel(uid, isolationTrees, numSamples)
val isolationForestCopy = new IsolationForestModel(uid, isolationTrees, numSamples, numFeatures)
.setParent(this.parent)
isolationForestCopy.setOutlierScoreThreshold(outlierScoreThreshold)
copyValues(isolationForestCopy, extra)
Expand Down
Expand Up @@ -49,14 +49,15 @@ private[isolationforest] case object IsolationForestModelReadWrite extends Loggi
implicit val format = DefaultFormats
val (metadata, treesData) = loadImpl(path, sparkSession)
val numSamples = (metadata.metadata \ "numSamples").extract[Int]
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val outlierScoreThreshold = (metadata.metadata \ "outlierScoreThreshold").extract[Double]

val trees = treesData.map {
case internalNode: InternalNode => new IsolationTree(internalNode.asInstanceOf[InternalNode])
case externalNode: ExternalNode => new IsolationTree(externalNode.asInstanceOf[ExternalNode])
}

val model = new IsolationForestModel(metadata.uid, trees, numSamples)
val model = new IsolationForestModel(metadata.uid, trees, numSamples, numFeatures)
metadata.setParams(model)
model.setOutlierScoreThreshold(outlierScoreThreshold)

Expand Down Expand Up @@ -237,7 +238,8 @@ private[isolationforest] case object IsolationForestModelReadWrite extends Loggi

val extraMetadata: JObject =
("outlierScoreThreshold", instance.getOutlierScoreThreshold) ~
("numSamples", instance.getNumSamples)
("numSamples", instance.getNumSamples) ~
("numFeatures", instance.getNumFeatures)
saveImplHelper(path, sparkSession, extraMetadata)
}

Expand All @@ -246,7 +248,7 @@ private[isolationforest] case object IsolationForestModelReadWrite extends Loggi
*
* @param path The path on disk used to save the ensemble model.
* @param spark The SparkSession instance to use.
* @param extraMetadata Metadata such as outlierScoreThreshold and numSamples.
* @param extraMetadata Metadata such as outlierScoreThreshold, numSamples, and numFeatures.
*/
private def saveImplHelper(path: String, spark: SparkSession, extraMetadata: JObject): Unit = {

Expand Down
@@ -1 +1 @@
{"class":"com.linkedin.relevance.isolationforest.IsolationForestModel","timestamp":1544084998332,"sparkVersion":"2.3.0.89","uid":"isolation-forest_746c9083c2c1","paramMap":{"predictionCol":"predictedLabel","maxFeatures":1.0,"scoreCol":"outlierScore","maxSamples":256.0,"randomSeed":1,"bootstrap":false,"contamination":0.02,"featuresCol":"features","numEstimators":100},"outlierScoreThreshold":0.6015323679815825,"numSamples":256}
{"class":"com.linkedin.relevance.isolationforest.IsolationForestModel","timestamp":1544084998332,"sparkVersion":"2.3.0.89","uid":"isolation-forest_746c9083c2c1","paramMap":{"predictionCol":"predictedLabel","maxFeatures":1.0,"scoreCol":"outlierScore","maxSamples":256.0,"randomSeed":1,"bootstrap":false,"contamination":0.02,"featuresCol":"features","numEstimators":100},"outlierScoreThreshold":0.6015323679815825,"numSamples":256,"numFeatures":6}
Expand Up @@ -47,6 +47,7 @@ class IsolationForestModelWriteReadTest extends Logging {
isolationForestModel1.extractParamMap.toString,
isolationForestModel2.extractParamMap.toString)
Assert.assertEquals(isolationForestModel1.getNumSamples, isolationForestModel2.getNumSamples)
Assert.assertEquals(isolationForestModel1.getNumFeatures, isolationForestModel2.getNumFeatures)
Assert.assertEquals(
isolationForestModel1.getOutlierScoreThreshold,
isolationForestModel2.getOutlierScoreThreshold)
Expand Down Expand Up @@ -110,6 +111,7 @@ class IsolationForestModelWriteReadTest extends Logging {
isolationForestModel1.extractParamMap.toString,
isolationForestModel2.extractParamMap.toString)
Assert.assertEquals(isolationForestModel1.getNumSamples, isolationForestModel2.getNumSamples)
Assert.assertEquals(isolationForestModel1.getNumFeatures, isolationForestModel2.getNumFeatures)
Assert.assertEquals(
isolationForestModel1.getOutlierScoreThreshold,
isolationForestModel2.getOutlierScoreThreshold)
Expand Down Expand Up @@ -207,7 +209,7 @@ class IsolationForestModelWriteReadTest extends Logging {
val spark = getSparkSession

// Create an isolation forest model with no isolation trees
val isolationForestModel1 = new IsolationForestModel("testUid", Array(), 1)
val isolationForestModel1 = new IsolationForestModel("testUid", Array(), numSamples = 1, numFeatures = 2)
isolationForestModel1.setOutlierScoreThreshold(0.5)

// Write the trained model to disk and then read it back from disk
Expand All @@ -221,6 +223,7 @@ class IsolationForestModelWriteReadTest extends Logging {
isolationForestModel1.extractParamMap.toString,
isolationForestModel2.extractParamMap.toString)
Assert.assertEquals(isolationForestModel1.getNumSamples, isolationForestModel2.getNumSamples)
Assert.assertEquals(isolationForestModel1.getNumFeatures, isolationForestModel2.getNumFeatures)
Assert.assertEquals(
isolationForestModel1.getOutlierScoreThreshold,
isolationForestModel2.getOutlierScoreThreshold)
Expand Down
2 changes: 1 addition & 1 deletion version.properties
@@ -1,3 +1,3 @@
# Version of the produced binaries.
# The version is inferred by shipkit-auto-version Gradle plugin (https://github.com/shipkit/shipkit-auto-version).
version=2.0.*
version=3.0.*

0 comments on commit ec66605

Please sign in to comment.