Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add APPROX_HISTOGRAM_K Operation #735

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,18 @@

package ai.chronon.aggregator.base

import ai.chronon.aggregator.base.FrequentItemType.ItemType
import ai.chronon.api._
import com.yahoo.memory.Memory
import com.yahoo.sketches.cpc.{CpcSketch, CpcUnion}
import com.yahoo.sketches.frequencies.{ErrorType, ItemsSketch}
import com.yahoo.sketches.kll.KllFloatsSketch
import com.yahoo.sketches.{
ArrayOfDoublesSerDe,
ArrayOfItemsSerDe,
ArrayOfLongsSerDe,
ArrayOfNumbersSerDe,
ArrayOfStringsSerDe
}
import com.yahoo.sketches.{ArrayOfDoublesSerDe, ArrayOfItemsSerDe, ArrayOfLongsSerDe, ArrayOfStringsSerDe}

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import java.util
import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag

class Sum[I: Numeric](inputType: DataType) extends SimpleAggregator[I, I, I] {
Expand Down Expand Up @@ -343,67 +341,261 @@ object CpcFriendly {
}
}

/**
  * Tags the element type held by a frequent-items sketch.
  *
  * The numeric `id` of each value (assigned in declaration order, starting at 0) is what
  * `FrequentItems.normalize` writes as the first byte of a serialized sketch, so the
  * declaration order below must not change.
  */
object FrequentItemType extends Enumeration {
  type ItemType = Value

  val StringItemType: Value = Value
  val LongItemType: Value = Value
  val DoubleItemType: Value = Value
}

/**
  * Intermediate representation used by the frequent-items aggregators: the sketch itself plus a tag
  * recording which element type it was built for.
  *
  * @param sketch     the underlying items sketch; the aggregators in this file update and merge it in place
  * @param sketchType element-type tag; its `id` is written as the leading byte by `FrequentItems.normalize`
  */
case class ItemsSketchIR[T](sketch: ItemsSketch[T], sketchType: ItemType)

/**
  * Type class describing an input type that can be stored in a frequent-items sketch.
  *
  * @tparam Input element type tracked by the sketch
  */
trait FrequentItemsFriendly[Input] {
// SerDe handed to ItemsSketch.toByteArray / ItemsSketch.getInstance when a sketch is (de)serialized.
def serializer: ArrayOfItemsSerDe[Input]
// Tag stored next to the sketch in ItemsSketchIR.
def sketchType: FrequentItemType.ItemType
}

/**
  * Implicit [[FrequentItemsFriendly]] instances for the element types supported by the frequent-items
  * aggregators: `String`, `java.lang.Long` and `java.lang.Double`.
  *
  * Each instance pairs the matching `ArrayOf*SerDe` with the corresponding [[FrequentItemType]] tag.
  * The serializer is a `def`, so every call hands out a fresh SerDe object.
  */
object FrequentItemsFriendly {
  implicit val stringIsFrequentItemsFriendly: FrequentItemsFriendly[String] =
    new FrequentItemsFriendly[String] {
      override def serializer: ArrayOfItemsSerDe[String] = new ArrayOfStringsSerDe
      override def sketchType: ItemType = FrequentItemType.StringItemType
    }

  implicit val longIsFrequentItemsFriendly: FrequentItemsFriendly[java.lang.Long] =
    new FrequentItemsFriendly[java.lang.Long] {
      override def serializer: ArrayOfItemsSerDe[java.lang.Long] = new ArrayOfLongsSerDe
      override def sketchType: ItemType = FrequentItemType.LongItemType
    }

  implicit val doubleIsFrequentItemsFriendly: FrequentItemsFriendly[java.lang.Double] =
    new FrequentItemsFriendly[java.lang.Double] {
      override def serializer: ArrayOfItemsSerDe[java.lang.Double] = new ArrayOfDoublesSerDe
      override def sketchType: ItemType = FrequentItemType.DoubleItemType
    }
}

/**
  * Approximate "most frequent items" aggregator backed by an `ItemsSketch`.
  *
  * @param mapSize   maximum number of entries reported by [[finalize]]; values `<= 0` yield an empty result
  * @param errorType error mode passed to `ItemsSketch.getFrequentItems` when finalizing
  * @tparam T element type; must have a [[FrequentItemsFriendly]] instance supplying its SerDe and type tag
  */
class FrequentItems[T: FrequentItemsFriendly](val mapSize: Int, val errorType: ErrorType = ErrorType.NO_FALSE_POSITIVES)
    extends SimpleAggregator[T, ItemsSketchIR[T], util.Map[String, Long]] {
  private type Sketch = ItemsSketchIR[T]

  // The ItemsSketch implementation requires a size with a positive power of 2
  // Initialize the sketch with the next closest power of 2
  val sketchSize: Int = if (mapSize > 1) Integer.highestOneBit(mapSize - 1) << 1 else 2

  override def outputType: DataType = MapType(StringType, LongType)

  override def irType: DataType = BinaryType

  /** Creates a new sketch of `sketchSize` seeded with a single occurrence of `input`. */
  override def prepare(input: T): ItemsSketchIR[T] = {
    val sketch = new ItemsSketch[T](sketchSize)
    val sketchType = implicitly[FrequentItemsFriendly[T]].sketchType
    sketch.update(input)
    ItemsSketchIR(sketch, sketchType)
  }

  /** Records one more occurrence of `input`; the sketch inside `ir` is mutated and `ir` is returned. */
  override def update(ir: Sketch, input: T): Sketch = {
    ir.sketch.update(input)
    ir
  }

  /** Merges `ir2` into `ir1`; the sketch inside `ir1` is mutated and `ir1` is returned. */
  override def merge(ir1: Sketch, ir2: Sketch): Sketch = {
    ir1.sketch.merge(ir2.sketch)
    ir1
  }

  // ItemsSketch doesn't have a proper copy method. So we serialize and deserialize.
  override def clone(ir: Sketch): Sketch = {
    val serializer = implicitly[FrequentItemsFriendly[T]].serializer
    val bytes = ir.sketch.toByteArray(serializer)
    val clonedSketch = ItemsSketch.getInstance(Memory.wrap(bytes), serializer)
    ItemsSketchIR(clonedSketch, ir.sketchType)
  }

  /**
    * Returns at most `mapSize` of the sketch's frequent items, keyed by `String.valueOf(item)` and mapped
    * to their estimated counts. The entries with the largest estimates are kept.
    */
  override def finalize(ir: Sketch): util.Map[String, Long] = {
    val result = new util.HashMap[String, Long]()
    if (mapSize > 0) {
      // Min-heap on the estimate: `head` is the smallest retained entry, so when the heap is full
      // a larger candidate evicts the current minimum. This keeps the top `mapSize` entries
      // regardless of the order in which the sketch reports its rows.
      val smallestFirst = Ordering.by[(T, Long), Long](_._2).reverse
      val heap = mutable.PriorityQueue[(T, Long)]()(smallestFirst)

      ir.sketch.getFrequentItems(errorType).foreach { row =>
        val estimate = row.getEstimate
        if (heap.size < mapSize) {
          heap.enqueue((row.getItem, estimate))
        } else if (heap.head._2 < estimate) {
          heap.dequeue()
          heap.enqueue((row.getItem, estimate))
        }
      }

      heap.foreach { case (item, estimate) => result.put(String.valueOf(item), estimate) }
    }
    result
  }

  /** Serializes `ir` as one leading byte holding `sketchType.id`, followed by the sketch's own bytes. */
  override def normalize(ir: Sketch): Array[Byte] = {
    val serializer = implicitly[FrequentItemsFriendly[T]].serializer
    // Prepend directly onto the byte array instead of routing every byte through a boxed Seq.
    ir.sketchType.id.toByte +: ir.sketch.toByteArray(serializer)
  }

  /**
    * Inverse of [[normalize]].
    *
    * @param normalized an `Array[Byte]` whose first byte is a [[FrequentItemType]] id and whose
    *                   remainder is a serialized `ItemsSketch`
    */
  override def denormalize(normalized: Any): Sketch = {
    val bytes = normalized.asInstanceOf[Array[Byte]]
    val sketchType = FrequentItemType(bytes.head)
    val serializer = implicitly[FrequentItemsFriendly[T]].serializer
    val sketch = ItemsSketch.getInstance[T](Memory.wrap(bytes.tail), serializer)
    ItemsSketchIR(sketch, sketchType)
  }

  /** Builds a fresh sketch of `sketchSize` from exact counts, adding each key with its count as weight. */
  def toSketch(values: util.Map[T, Long]): Sketch = {
    val sketch = new ItemsSketch[T](sketchSize)
    val sketchType = implicitly[FrequentItemsFriendly[T]].sketchType

    values.asScala.foreach({ case (k, v) => sketch.update(k, v) })

    ItemsSketchIR(sketch, sketchType)
  }
}

/**
  * Intermediate representation of `ApproxHistogram`: either an exact count map or a frequent-items sketch.
  *
  * `ApproxHistogram.merge`, `finalize` and `clone` throw `IllegalStateException` unless exactly one of
  * `sketch` / `histogram` is defined.
  *
  * @param isApprox  set to `true` by `ApproxHistogram` whenever it builds a sketch-backed IR
  * @param sketch    sketch state, when in approximate mode
  * @param histogram exact counts per key, when in exact mode
  */
case class ApproxHistogramIr[T: FrequentItemsFriendly](
isApprox: Boolean,
sketch: Option[ItemsSketchIR[T]],
histogram: Option[util.Map[T, Long]]
)

/**
  * Form of `ApproxHistogramIr` that `ApproxHistogram.normalize` writes through an `ObjectOutputStream`
  * and `ApproxHistogram.denormalize` reads back; the sketch is carried as bytes produced by
  * `FrequentItems.normalize`.
  */
case class ApproxHistogramIrSerializable[T: FrequentItemsFriendly](
isApprox: Boolean,
// The ItemsSketch isn't directly serializable
sketch: Option[Array[Byte]],
histogram: Option[util.Map[T, Long]]
)

/**
  * Histogram aggregator that counts exactly while the number of distinct keys stays at or below
  * `mapSize`, and hands over to a [[FrequentItems]] sketch once the exact map grows past that.
  *
  * The sketch approximates counts both for keys inside and outside the top k, so the exact map is
  * preferred for as long as it is small enough.
  *
  * @param mapSize   largest number of distinct keys kept exactly; also the `mapSize` of the backing sketch
  * @param errorType error mode forwarded to the backing [[FrequentItems]] aggregator
  */
class ApproxHistogram[T: FrequentItemsFriendly](mapSize: Int, errorType: ErrorType = ErrorType.NO_FALSE_POSITIVES)
    extends SimpleAggregator[T, ApproxHistogramIr[T], util.Map[String, Long]] {
  private val frequentItemsAggregator = new FrequentItems[T](mapSize, errorType)

  override def outputType: DataType = MapType(StringType, LongType)

  override def irType: DataType = BinaryType

  /** Starts in exact mode with a single key counted once. */
  override def prepare(input: T): ApproxHistogramIr[T] = {
    val counts = new util.HashMap[T, Long]()
    counts.put(input, 1L)
    exactIr(counts)
  }

  /** Counts one more `input`. An exact map wins over a sketch if both happen to be present. */
  override def update(ir: ApproxHistogramIr[T], input: T): ApproxHistogramIr[T] =
    ir.histogram match {
      case Some(counts) =>
        increment(input, 1L, counts)
        toIr(counts)
      case None =>
        ir.sketch match {
          case Some(state) =>
            state.sketch.update(input)
            sketchIr(state)
          case None => missingState
        }
    }

  /** Combines two IRs; the result is exact only if both inputs are exact and the union still fits. */
  override def merge(ir1: ApproxHistogramIr[T], ir2: ApproxHistogramIr[T]): ApproxHistogramIr[T] =
    (stateOf(ir1), stateOf(ir2)) match {
      case (Left(counts1), Left(counts2))   => mergeExact(counts1, counts2)
      case (Right(sketch1), Right(sketch2)) => sketchIr(frequentItemsAggregator.merge(sketch1, sketch2))
      case (Left(counts), Right(sketch))    => foldInto(sketch, counts)
      case (Right(sketch), Left(counts))    => foldInto(sketch, counts)
    }

  override def finalize(ir: ApproxHistogramIr[T]): util.Map[String, Long] =
    stateOf(ir) match {
      case Right(sketch) => frequentItemsAggregator.finalize(sketch)
      case Left(counts)  => toOutputMap(counts)
    }

  override def clone(ir: ApproxHistogramIr[T]): ApproxHistogramIr[T] =
    stateOf(ir) match {
      case Right(sketch) => sketchIr(frequentItemsAggregator.clone(sketch))
      case Left(counts)  => exactIr(new util.HashMap[T, Long](counts))
    }

  /** Java-serializes the IR, with the sketch (if any) first turned into bytes by the sketch aggregator. */
  override def normalize(ir: ApproxHistogramIr[T]): Any = {
    val payload = ApproxHistogramIrSerializable(
      isApprox = ir.isApprox,
      sketch = ir.sketch.map(frequentItemsAggregator.normalize),
      histogram = ir.histogram
    )

    val buffer = new ByteArrayOutputStream()
    val writer = new ObjectOutputStream(buffer)

    try {
      writer.writeObject(payload)
    } finally {
      writer.close()
      buffer.close()
    }

    // Read the bytes only after the object stream has been closed.
    buffer.toByteArray
  }

  /** Inverse of [[normalize]]; expects the `Array[Byte]` it produced. */
  override def denormalize(ir: Any): ApproxHistogramIr[T] = {
    val bytes = ir.asInstanceOf[Array[Byte]]

    val buffer = new ByteArrayInputStream(bytes)
    // NOTE(review): plain Java deserialization - presumably only fed bytes written by normalize; confirm.
    val reader = new ObjectInputStream(buffer)

    try {
      val payload = reader.readObject().asInstanceOf[ApproxHistogramIrSerializable[T]]
      ApproxHistogramIr[T](
        isApprox = payload.isApprox,
        sketch = payload.sketch.map(frequentItemsAggregator.denormalize),
        histogram = payload.histogram
      )
    } finally {
      reader.close()
      buffer.close()
    }
  }

  private def missingState: Nothing = throw new IllegalStateException("Histogram state is missing")

  // Left = exact counts, Right = sketch; any other combination of the two options is rejected.
  private def stateOf(ir: ApproxHistogramIr[T]): Either[util.Map[T, Long], ItemsSketchIR[T]] =
    (ir.histogram, ir.sketch) match {
      case (Some(counts), None) => Left(counts)
      case (None, Some(sketch)) => Right(sketch)
      case _                    => missingState
    }

  private def exactIr(counts: util.Map[T, Long]): ApproxHistogramIr[T] =
    ApproxHistogramIr[T](isApprox = false, sketch = None, histogram = Some(counts))

  private def sketchIr(sketch: ItemsSketchIR[T]): ApproxHistogramIr[T] =
    ApproxHistogramIr[T](isApprox = true, sketch = Some(sketch), histogram = None)

  // Sums both exact maps into a new one, leaving the inputs untouched.
  private def mergeExact(left: util.Map[T, Long], right: util.Map[T, Long]): ApproxHistogramIr[T] = {
    val merged = new util.HashMap[T, Long]()
    Seq(left, right).foreach { counts =>
      counts.asScala.foreach { case (key, count) => increment(key, count, merged) }
    }
    toIr(merged)
  }

  // Feeds exact counts into the given sketch (mutating it) as weighted updates.
  private def foldInto(sketch: ItemsSketchIR[T], counts: util.Map[T, Long]): ApproxHistogramIr[T] = {
    counts.asScala.foreach { case (key, count) => sketch.sketch.update(key, count) }
    sketchIr(sketch)
  }

  // Stays exact while the map holds at most mapSize keys; otherwise converts it to a sketch.
  private def toIr(hist: util.Map[T, Long]): ApproxHistogramIr[T] =
    if (hist.size <= mapSize) exactIr(hist)
    else sketchIr(frequentItemsAggregator.toSketch(hist))

  private def increment(value: T, times: Long, values: util.Map[T, Long]): Unit = {
    val current = values.getOrDefault(value, 0L)
    values.put(value, current + times)
  }

  private def toOutputMap(map: util.Map[T, Long]): util.Map[String, Long] = {
    val rendered = new util.HashMap[String, Long](map.size())
    val entries = map.entrySet().iterator()
    while (entries.hasNext) {
      val entry = entries.next()
      rendered.put(String.valueOf(entry.getKey), entry.getValue)
    }
    rendered
  }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ object ColumnAggregator {
// Casts the untyped input to A and converts it to Float through A's Numeric instance.
private def toFloat[A](inp: Any)(implicit numeric: Numeric[A]): Float =
  numeric.toFloat(inp.asInstanceOf[A])
// Casts the untyped input to A and converts it to Long through A's Numeric instance.
private def toLong[A](inp: Any)(implicit numeric: Numeric[A]): Long =
  numeric.toLong(inp.asInstanceOf[A])
// Maps a Boolean input to 1 (true) or 0 (false).
private def boolToLong(inp: Any): Long = {
  val flag = inp.asInstanceOf[Boolean]
  if (flag) 1L else 0L
}
// Converts a numeric input of type A to a boxed java.lang.Long.
private def toJavaLong[A](inp: Any)(implicit numeric: Numeric[A]): java.lang.Long = {
  val asLong = numeric.toLong(inp.asInstanceOf[A])
  java.lang.Long.valueOf(asLong)
}
// Converts a numeric input of type A to a boxed java.lang.Double.
private def toJavaDouble[A](inp: Any)(implicit numeric: Numeric[A]): java.lang.Double = {
  val asDouble = numeric.toDouble(inp.asInstanceOf[A])
  java.lang.Double.valueOf(asDouble)
}

def construct(baseInputType: DataType,
aggregationPart: AggregationPart,
Expand Down Expand Up @@ -256,6 +260,17 @@ object ColumnAggregator {
aggregationPart.operation match {
case Operation.COUNT => simple(new Count)
case Operation.HISTOGRAM => simple(new Histogram(aggregationPart.getInt("k", Some(0))))
case Operation.APPROX_HISTOGRAM_K =>
val k = aggregationPart.getInt("k", Some(8))
inputType match {
case IntType => simple(new ApproxHistogram[java.lang.Long](k), toJavaLong[Int])
case LongType => simple(new ApproxHistogram[java.lang.Long](k))
case ShortType => simple(new ApproxHistogram[java.lang.Long](k), toJavaLong[Short])
case DoubleType => simple(new ApproxHistogram[java.lang.Double](k))
case FloatType => simple(new ApproxHistogram[java.lang.Double](k), toJavaDouble[Float])
case StringType => simple(new ApproxHistogram[String](k))
case _ => mismatchException
}
case Operation.SUM =>
inputType match {
case IntType => simple(new Sum[Long](LongType), toLong[Int])
Expand Down