Skip to content

Commit

Permalink
Add APPROX_HISTOGRAM_K Operation
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrooks-stripe committed Mar 29, 2024
1 parent 92a78f1 commit c7769f6
Show file tree
Hide file tree
Showing 7 changed files with 191 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,10 @@ import com.yahoo.memory.Memory
import com.yahoo.sketches.cpc.{CpcSketch, CpcUnion}
import com.yahoo.sketches.frequencies.{ErrorType, ItemsSketch}
import com.yahoo.sketches.kll.KllFloatsSketch
import com.yahoo.sketches.{
ArrayOfDoublesSerDe,
ArrayOfItemsSerDe,
ArrayOfLongsSerDe,
ArrayOfNumbersSerDe,
ArrayOfStringsSerDe
}
import com.yahoo.sketches.{ArrayOfDoublesSerDe, ArrayOfItemsSerDe, ArrayOfLongsSerDe, ArrayOfStringsSerDe}

import java.util
import java.{lang, util}
import scala.collection.mutable
import scala.reflect.ClassTag

class Sum[I: Numeric](inputType: DataType) extends SimpleAggregator[I, I, I] {
Expand Down Expand Up @@ -343,26 +338,32 @@ object CpcFriendly {
}
}

// Concrete, non-generic ItemsSketch subclasses, one per supported item type.
// ChrononKryoRegistrator registers each of these classes with its own
// ItemsSketchKryoSerializer instance (String, java.lang.Long, java.lang.Double);
// presumably the distinct runtime classes are what let Kryo choose the matching serializer.
class StringItemsSketch(maxMapSize: Int) extends ItemsSketch[String](maxMapSize)
class LongItemsSketch(maxMapSize: Int) extends ItemsSketch[java.lang.Long](maxMapSize)
class DoubleItemsSketch(maxMapSize: Int) extends ItemsSketch[java.lang.Double](maxMapSize)

/**
  * Type class for item types that can be tracked by a frequent-items sketch.
  *
  * @tparam Input the item type stored in the sketch
  */
trait FrequentItemsFriendly[Input] {
// SerDe handed to ItemsSketch when a sketch of `Input` is written to or read from bytes.
def serializer: ArrayOfItemsSerDe[Input]
// Builds a new sketch for `Input`; `maxMapSize` is passed through to the sketch constructor.
def sketch(maxMapSize: Int): ItemsSketch[Input]
}

object FrequentItemsFriendly {
// FrequentItemsFriendly instance for String items.
implicit val stringIsFrequentItemsFriendly: FrequentItemsFriendly[String] =
  new FrequentItemsFriendly[String] {
    override def serializer: ArrayOfItemsSerDe[String] = new ArrayOfStringsSerDe()

    override def sketch(maxMapSize: Int): ItemsSketch[String] =
      new StringItemsSketch(maxMapSize)
  }

implicit val longIsCpcFriendly: FrequentItemsFriendly[java.lang.Long] = new FrequentItemsFriendly[java.lang.Long] {
override def serializer: ArrayOfItemsSerDe[java.lang.Long] = new ArrayOfLongsSerDe
}
implicit val doubleIsCpcFriendly: FrequentItemsFriendly[java.lang.Double] =
// FrequentItemsFriendly instance for boxed Long items.
implicit val longIsFrequentItemsFriendly: FrequentItemsFriendly[java.lang.Long] =
  new FrequentItemsFriendly[java.lang.Long] {
    override def serializer: ArrayOfItemsSerDe[java.lang.Long] = new ArrayOfLongsSerDe()

    override def sketch(maxMapSize: Int): ItemsSketch[java.lang.Long] =
      new LongItemsSketch(maxMapSize)
  }

// FrequentItemsFriendly instance for boxed Double items.
implicit val doubleIsFrequentItemsFriendly: FrequentItemsFriendly[java.lang.Double] =
  new FrequentItemsFriendly[java.lang.Double] {
    override def serializer: ArrayOfItemsSerDe[java.lang.Double] = new ArrayOfDoublesSerDe()

    override def sketch(maxMapSize: Int): ItemsSketch[java.lang.Double] =
      new DoubleItemsSketch(maxMapSize)
  }

implicit val BinaryIsCpcFriendly: FrequentItemsFriendly[Number] = new FrequentItemsFriendly[Number] {
override def serializer: ArrayOfItemsSerDe[Number] = new ArrayOfNumbersSerDe
}
}

class FrequentItems[T: FrequentItemsFriendly](val mapSize: Int, val errorType: ErrorType = ErrorType.NO_FALSE_POSITIVES)
Expand All @@ -372,7 +373,12 @@ class FrequentItems[T: FrequentItemsFriendly](val mapSize: Int, val errorType: E
override def irType: DataType = BinaryType
type Sketch = ItemsSketch[T]
/**
  * Creates a new sketch that holds a single occurrence of `input`.
  *
  * The ItemsSketch implementation requires a size that is a positive power of 2, so
  * `mapSize` is rounded up to the next power of 2 (values below 1 are treated as 1).
  *
  * @param input the first item to record
  * @return a sketch containing only `input`
  * @throws IllegalArgumentException if `mapSize` exceeds 2^30, the largest power of 2 an Int can hold
  */
override def prepare(input: T): Sketch = {
  val requested = math.max(mapSize, 1)
  require(requested <= (1 << 30), s"mapSize must be at most ${1 << 30}, got $mapSize")

  // Integer bit arithmetic rather than ceil(log(n) / log(2)): the floating-point quotient can
  // land just above an exact integer for an input that is already a power of 2, which would
  // round the size up one step further than necessary.
  val sketchSize =
    if (requested == 1) 1
    else Integer.highestOneBit(requested - 1) << 1

  val sketch = implicitly[FrequentItemsFriendly[T]].sketch(sketchSize)
  sketch.update(input)
  sketch
}
Expand All @@ -393,8 +399,26 @@ class FrequentItems[T: FrequentItemsFriendly](val mapSize: Int, val errorType: E
ItemsSketch.getInstance[T](Memory.wrap(bytes), serDe)
}

override def finalize(ir: Sketch): Map[T, Long] =
ir.getFrequentItems(errorType).map(sk => sk.getItem -> sk.getEstimate).toMap
/**
  * Converts the sketch into at most `mapSize` (item -> estimated count) entries, keeping the
  * entries with the largest estimates.
  *
  * The sketch can report more than `mapSize` items, so the result is truncated here.
  *
  * @param ir the sketch to read from
  * @return the highest-estimate entries; empty when `mapSize` is zero or negative
  */
override def finalize(ir: Sketch): Map[T, Long] =
  if (mapSize <= 0) {
    Map.empty
  } else {
    ir.getFrequentItems(errorType)
      .map(row => row.getItem -> row.getEstimate)
      // Sort explicitly so the truncation is correct whatever order the sketch returns rows in.
      .sortBy { case (_, estimate) => -estimate }
      .take(mapSize)
      .toMap
  }

override def normalize(ir: Sketch): Array[Byte] = {
val serDe = implicitly[FrequentItemsFriendly[T]].serializer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ object ColumnAggregator {
private def toFloat[A: Numeric](inp: Any): Float = implicitly[Numeric[A]].toFloat(inp.asInstanceOf[A])
private def toLong[A: Numeric](inp: Any) = implicitly[Numeric[A]].toLong(inp.asInstanceOf[A])
private def boolToLong(inp: Any): Long = if (inp.asInstanceOf[Boolean]) 1 else 0
// Treats `inp` as an `A`, converts it to a Long through Numeric[A], and boxes the result.
private def toJavaLong[A: Numeric](inp: Any): java.lang.Long = {
  val numeric = implicitly[Numeric[A]]
  Long.box(numeric.toLong(inp.asInstanceOf[A]))
}
// Treats `inp` as an `A`, converts it to a Double through Numeric[A], and boxes the result.
private def toJavaDouble[A: Numeric](inp: Any): java.lang.Double = {
  val numeric = implicitly[Numeric[A]]
  Double.box(numeric.toDouble(inp.asInstanceOf[A]))
}

def construct(baseInputType: DataType,
aggregationPart: AggregationPart,
Expand Down Expand Up @@ -256,6 +260,17 @@ object ColumnAggregator {
aggregationPart.operation match {
case Operation.COUNT => simple(new Count)
case Operation.HISTOGRAM => simple(new Histogram(aggregationPart.getInt("k", Some(0))))
case Operation.APPROX_HISTOGRAM_K =>
val k = aggregationPart.getInt("k", Some(8))
inputType match {
case IntType => simple(new FrequentItems[java.lang.Long](k), toJavaLong[Int])
case LongType => simple(new FrequentItems[java.lang.Long](k))
case ShortType => simple(new FrequentItems[java.lang.Long](k), toJavaLong[Short])
case DoubleType => simple(new FrequentItems[java.lang.Double](k))
case FloatType => simple(new FrequentItems[java.lang.Double](k), toJavaDouble[Float])
case StringType => simple(new FrequentItems[String](k))
case _ => mismatchException
}
case Operation.SUM =>
inputType match {
case IntType => simple(new Sum[Long](LongType), toLong[Int])
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package ai.chronon.aggregator.test

import ai.chronon.aggregator.base.FrequentItems
import junit.framework.TestCase

/** Tests for FrequentItems: truncation to the requested size, fewer items than the size, and size zero. */
class FrequentItemsTest extends TestCase {

  // A size that is not a power of two still produces exactly `size` entries.
  def testNonPowerOfTwoAndTruncate(): Unit = {
    val aggregator = new FrequentItems[String](3)
    // prepare() already records one occurrence of "4".
    val sketch = aggregator.prepare("4")

    Seq("4" -> 3, "3" -> 3, "2" -> 2, "1" -> 1).foreach {
      case (value, times) =>
        for (_ <- 1 to times) aggregator.update(sketch, value)
    }

    val expected = Map("4" -> 4, "3" -> 3, "2" -> 2)
    assert(expected == aggregator.finalize(sketch))
  }

  // With fewer distinct items than the requested size, every item is reported.
  def testLessItemsThanSize(): Unit = {
    val aggregator = new FrequentItems[java.lang.Long](10)
    // prepare() already records one occurrence of 3.
    val sketch = aggregator.prepare(3)

    Seq(3L -> 2, 2L -> 2, 1L -> 1).foreach {
      case (value, times) =>
        for (_ <- 1 to times) aggregator.update(sketch, value)
    }

    val expected = Map(3 -> 3, 2 -> 2, 1 -> 1)
    assert(expected == aggregator.finalize(sketch))
  }

  // A size of zero yields an empty result regardless of what was added.
  def testZeroSize(): Unit = {
    val aggregator = new FrequentItems[java.lang.Double](0)
    val sketch = aggregator.prepare(3.0)

    Seq(3.0 -> 2, 2.0 -> 2, 1.0 -> 1).foreach {
      case (value, times) =>
        for (_ <- 1 to times) aggregator.update(sketch, value)
    }

    assert(Map() == aggregator.finalize(sketch))
  }
}
2 changes: 2 additions & 0 deletions api/py/ai/chronon/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class Operation:
HISTOGRAM = ttypes.Operation.HISTOGRAM
# k truncates the map to top_k most frequent items, 0 turns off truncation
HISTOGRAM_K = collector(ttypes.Operation.HISTOGRAM)
# k truncates the map to top_k most frequent items, k is required and results are bounded
APPROX_HISTOGRAM_K = collector(ttypes.Operation.APPROX_HISTOGRAM_K)
FIRST_K = collector(ttypes.Operation.FIRST_K)
LAST_K = collector(ttypes.Operation.LAST_K)
TOP_K = collector(ttypes.Operation.TOP_K)
Expand Down
3 changes: 2 additions & 1 deletion api/thrift/api.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ enum Operation {
TOP_K = 15,
BOTTOM_K = 16

HISTOGRAM = 17 // use this only if you know the set of inputs is bounded
HISTOGRAM = 17, // use this only if you know the set of inputs is bounded
APPROX_HISTOGRAM_K = 18
}

// integers map to milliseconds in the timeunit
Expand Down
29 changes: 15 additions & 14 deletions spark/src/main/scala/ai/chronon/spark/ChrononKryoRegistrator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,18 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.chronon.spark

import ai.chronon.aggregator.base.{DoubleItemsSketch, FrequentItemsFriendly, LongItemsSketch, StringItemsSketch}
import ai.chronon.aggregator.base.FrequentItemsFriendly._
import com.esotericsoftware.kryo.io.{Input, Output}
import com.esotericsoftware.kryo.{Kryo, Serializer}
import com.yahoo.memory.Memory
import com.yahoo.sketches.ArrayOfStringsSerDe
import com.yahoo.sketches.cpc.CpcSketch
import com.yahoo.sketches.frequencies.ItemsSketch
import org.apache.spark.serializer.KryoRegistrator

import scala.reflect.runtime.universe._
import org.apache.spark.SPARK_VERSION

class CpcSketchKryoSerializer extends Serializer[CpcSketch] {
Expand All @@ -38,22 +40,19 @@ class CpcSketchKryoSerializer extends Serializer[CpcSketch] {
CpcSketch.heapify(bytes)
}
}

class ItemsSketchKryoSerializer extends Serializer[ItemSketchSerializable] {
override def write(kryo: Kryo, output: Output, sketch: ItemSketchSerializable): Unit = {
val serDe = new ArrayOfStringsSerDe
val bytes = sketch.sketch.toByteArray(serDe)
class ItemsSketchKryoSerializer[T: FrequentItemsFriendly: TypeTag] extends Serializer[ItemsSketch[T]] {
override def write(kryo: Kryo, output: Output, sketch: ItemsSketch[T]): Unit = {
val serializer = implicitly[FrequentItemsFriendly[T]].serializer
val bytes = sketch.toByteArray(serializer)
output.writeInt(bytes.size)
output.writeBytes(bytes)
}

override def read(kryo: Kryo, input: Input, `type`: Class[ItemSketchSerializable]): ItemSketchSerializable = {
override def read(kryo: Kryo, input: Input, `type`: Class[ItemsSketch[T]]): ItemsSketch[T] = {
val size = input.readInt()
val bytes = input.readBytes(size)
val serDe = new ArrayOfStringsSerDe
val result = new ItemSketchSerializable
result.sketch = ItemsSketch.getInstance[String](Memory.wrap(bytes), serDe)
result
val serializer = implicitly[FrequentItemsFriendly[T]].serializer
ItemsSketch.getInstance[T](Memory.wrap(bytes), serializer)
}
}

Expand Down Expand Up @@ -127,7 +126,6 @@ class ChrononKryoRegistrator extends KryoRegistrator {
"scala.reflect.ManifestFactory$LongManifest",
"org.apache.spark.sql.execution.joins.EmptyHashedRelation$",
"scala.reflect.ManifestFactory$$anon$1",
"scala.reflect.ClassTag$GenericClassTag",
"org.apache.spark.sql.execution.datasources.InMemoryFileIndex$SerializableFileStatus",
"org.apache.spark.sql.execution.datasources.InMemoryFileIndex$SerializableBlockLocation",
"scala.reflect.ManifestFactory$$anon$10",
Expand All @@ -144,10 +142,13 @@ class ChrononKryoRegistrator extends KryoRegistrator {
case _: ClassNotFoundException => // do nothing
}
}

kryo.register(classOf[Array[Array[Array[AnyRef]]]])
kryo.register(classOf[Array[Array[AnyRef]]])
kryo.register(classOf[CpcSketch], new CpcSketchKryoSerializer())
kryo.register(classOf[ItemSketchSerializable], new ItemsSketchKryoSerializer())
kryo.register(classOf[StringItemsSketch], new ItemsSketchKryoSerializer[String])
kryo.register(classOf[LongItemsSketch], new ItemsSketchKryoSerializer[java.lang.Long])
kryo.register(classOf[DoubleItemsSketch], new ItemsSketchKryoSerializer[java.lang.Double])
kryo.register(classOf[Array[ItemSketchSerializable]])
}
}
56 changes: 51 additions & 5 deletions spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,12 @@ class GroupByTest {
new Analyzer(tableUtils, groupByConf, endPartition, today).analyzeGroupBy(groupByConf, enableHitter = false)

print(aggregationsMetadata)
assertTrue(aggregationsMetadata.length == 1)
assertEquals(aggregationsMetadata(0).name, "time_spent_ms")
assertEquals(aggregationsMetadata(0).columnType, LongType)
assertTrue(aggregationsMetadata.length == 2)
val columns = aggregationsMetadata.map(a => a.name -> a.columnType).toMap
assertEquals(Map(
"time_spent_ms" -> LongType,
"price" -> DoubleType
), columns)
}

// test that OrderByLimit and OrderByLimitTimed serialization works well with Spark's data type
Expand Down Expand Up @@ -401,7 +404,8 @@ class GroupByTest {
val sourceSchema = List(
Column("user", StringType, 10000),
Column("item", StringType, 100),
Column("time_spent_ms", LongType, 5000)
Column("time_spent_ms", LongType, 5000),
Column("price", DoubleType, 100)
)
val namespace = "chronon_test"
val sourceTable = s"$namespace.test_group_by_steps$suffix"
Expand All @@ -410,7 +414,7 @@ class GroupByTest {
DataFrameGen.events(spark, sourceSchema, count = 1000, partitions = 200).save(sourceTable)
val source = Builders.Source.events(
query =
Builders.Query(selects = Builders.Selects("ts", "item", "time_spent_ms"), startPartition = startPartition),
Builders.Query(selects = Builders.Selects("ts", "item", "time_spent_ms", "price"), startPartition = startPartition),
table = sourceTable
)
(source, endPartition)
Expand Down Expand Up @@ -484,6 +488,48 @@ class GroupByTest {
additionalAgg = aggs)
}

// Runs a backfill with one APPROX_HISTOGRAM_K aggregation (k = 4, 15-day and 60-day windows)
// for each of the "item", "ts" and "price" columns.
@Test
def testApproxHistograms(): Unit = {
  val (source, endPartition) = createTestSource(suffix = "_approx_histogram")
  val tableUtils = TableUtils(spark)
  val namespace = "test_approx_histograms"

  val aggs = Seq("item", "ts", "price").map { column =>
    Builders.Aggregation(
      operation = Operation.APPROX_HISTOGRAM_K,
      inputColumn = column,
      windows = Seq(new Window(15, TimeUnit.DAYS), new Window(60, TimeUnit.DAYS)),
      argMap = Map("k" -> "4")
    )
  }

  backfill(
    name = "unit_test_group_by_approx_histograms",
    source = source,
    endPartition = endPartition,
    namespace = namespace,
    tableUtils = tableUtils,
    additionalAgg = aggs
  )
}

@Test
def testReplaceJoinSource(): Unit = {
val namespace = "replace_join_source_ns"
Expand Down

0 comments on commit c7769f6

Please sign in to comment.