Skip to content

Commit

Permalink
Added array encoding and strict mode
Browse files Browse the repository at this point in the history
  • Loading branch information
sksamuel committed Apr 28, 2024
1 parent 6de95d4 commit b61b1bc
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ fun interface Decoder<T> {
Int::class -> if (useStrictPrimitiveDecoders) StrictIntDecoder else IntDecoder
Long::class -> if (useStrictPrimitiveDecoders) StrictLongDecoder else LongDecoder
List::class -> ListDecoder(decoderFor(type.arguments.first().type!!))
LongArray::class -> LongArrayDecoder(if (useStrictPrimitiveDecoders) StrictLongDecoder else LongDecoder)
IntArray::class -> IntArrayDecoder(if (useStrictPrimitiveDecoders) StrictIntDecoder else IntDecoder)
Set::class -> SetDecoder(decoderFor(type.arguments.first().type!!))
Map::class -> MapDecoder(decoderFor(type.arguments[1].type!!))
LocalTime::class -> LocalTimeDecoder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,36 @@ package com.sksamuel.centurion.avro.decoders
import org.apache.avro.Schema
import org.apache.avro.generic.GenericData

class IntArrayDecoder(private val decoder: Decoder<Int>) : Decoder<IntArray> {
override fun decode(schema: Schema): (Any?) -> IntArray {
require(schema.type == Schema.Type.ARRAY)
val decode = decoder.decode(schema.elementType)
return { value ->
when (value) {
is GenericData.Array<*> -> value.map { decode.invoke(it) }.toTypedArray().toIntArray()
is List<*> -> value.map { decode.invoke(it) }.toTypedArray().toIntArray()
is Array<*> -> value.map { decode.invoke(it) }.toTypedArray().toIntArray()
else -> error("Unsupported list type $value")
}
}
}
}

class LongArrayDecoder(private val decoder: Decoder<Long>) : Decoder<LongArray> {
override fun decode(schema: Schema): (Any?) -> LongArray {
require(schema.type == Schema.Type.ARRAY)
val decode = decoder.decode(schema.elementType)
return { value ->
when (value) {
is GenericData.Array<*> -> value.map { decode.invoke(it) }.toTypedArray().toLongArray()
is List<*> -> value.map { decode.invoke(it) }.toTypedArray().toLongArray()
is Array<*> -> value.map { decode.invoke(it) }.toTypedArray().toLongArray()
else -> error("Unsupported list type $value")
}
}
}
}

class ListDecoder<T>(private val decoder: Decoder<T>) : Decoder<List<T>> {
override fun decode(schema: Schema): (Any?) -> List<T> {
require(schema.type == Schema.Type.ARRAY)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ fun interface Encoder<T> {
BigDecimal::class -> BigDecimalStringEncoder
Set::class -> SetEncoder(encoderFor(type.arguments.first().type!!))
List::class -> ListEncoder(encoderFor(type.arguments.first().type!!))
LongArray::class -> LongArrayEncoder(LongEncoder)
IntArray::class -> IntArrayEncoder(IntEncoder)
Map::class -> MapEncoder(
if (globalUseJavaString) JavaStringEncoder else StringEncoder,
encoderFor(type.arguments[1].type!!)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import org.apache.avro.Schema
import org.apache.avro.generic.GenericArray

/**
* An [Encoder] for Arrays of [T] that encodes into an Avro [GenericArray].
* An [Encoder] for Arrays of [T].
*/
class ArrayEncoder<T>(private val encoder: Encoder<T>) : Encoder<Array<T>> {
override fun encode(schema: Schema): (Array<T>) -> Any? {
Expand All @@ -18,6 +18,34 @@ class ArrayEncoder<T>(private val encoder: Encoder<T>) : Encoder<Array<T>> {
}
}

/**
* An [Encoder] for LongArrays.
*/
class LongArrayEncoder(private val encoder: Encoder<Long>) : Encoder<LongArray> {
override fun encode(schema: Schema): (LongArray) -> Any? {
require(schema.type == Schema.Type.ARRAY)
val elements = encoder.encode(schema.elementType)
return { value ->
if (value.isEmpty()) emptyList()
else value.map { elements.invoke(it) }
}
}
}

/**
* An [Encoder] for IntArrays.
*/
class IntArrayEncoder(private val encoder: Encoder<Int>) : Encoder<IntArray> {
override fun encode(schema: Schema): (IntArray) -> Any? {
require(schema.type == Schema.Type.ARRAY)
val elements = encoder.encode(schema.elementType)
return { value ->
if (value.isEmpty()) emptyList()
else value.map { elements.invoke(it) }
}
}
}

/**
* An [Encoder] for Lists of [T] that encodes into an Avro [GenericArray].
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ class ReflectionSchemaBuilder(
Float::class -> builder.floatType()
Set::class -> builder.array().items(schemaFor(type.arguments.first().type!!))
List::class -> builder.array().items(schemaFor(type.arguments.first().type!!))
Array::class -> builder.array().items(schemaFor(type.arguments.first().type!!))
LongArray::class -> builder.array().items(Schema.create(Schema.Type.LONG))
IntArray::class -> builder.array().items(Schema.create(Schema.Type.INT))
Map::class -> builder.map().values(schemaFor(type.arguments[1].type!!))
is KClass<*> -> if (classifier.java.isEnum)
builder
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package com.sksamuel.centurion.avro

import com.sksamuel.centurion.avro.decoders.SpecificRecordDecoder
import com.sksamuel.centurion.avro.encoders.SpecificRecordEncoder
import com.sksamuel.centurion.avro.encoders.Wine
import com.sksamuel.centurion.avro.generation.ReflectionSchemaBuilder
import io.kotest.core.spec.style.FunSpec
import io.kotest.matchers.shouldBe

class RoundTripTest : FunSpec() {
init {
test("round trip encode / decode") {
val schema = ReflectionSchemaBuilder().schema(RoundTrip::class)
val rt = RoundTrip(
s = null,
b = false,
l = 1436,
d = 4.5,
i = 7799,
f = 6.7f,
sets = setOf("foo", "bar"),
lists = listOf(6, 7),
arrays = longArrayOf(6L, 7L),
maps = mapOf(),
wine = Wine.Shiraz,
)
val actual = SpecificRecordDecoder(RoundTrip::class).decode(schema)
.invoke(SpecificRecordEncoder(RoundTrip::class).encode(schema).invoke(rt))
actual.s shouldBe actual.s
actual.b shouldBe actual.b
actual.l shouldBe actual.l
actual.d shouldBe actual.d
actual.i shouldBe actual.i
actual.f shouldBe actual.f
actual.sets shouldBe actual.sets
actual.lists shouldBe actual.lists
actual.maps shouldBe actual.maps
actual.wine shouldBe actual.wine
actual.arrays shouldBe actual.arrays
}
}
}

data class RoundTrip(
val s: String?,
val b: Boolean,
val l: Long,
val d: Double,
val i: Int,
val f: Float,
val sets: Set<String>,
val lists: List<Int>,
val arrays: LongArray,
val maps: Map<String, Double>,
val wine: Wine?,
)

0 comments on commit b61b1bc

Please sign in to comment.