Skip to content

Commit

Permalink
Merge pull request #110 from williamboxhall/williamboxhall/polymorphi…
Browse files Browse the repository at this point in the history
…c_open_support

Polymorphic open support (Take 2)
  • Loading branch information
thake committed Sep 7, 2021
2 parents 4e7eb4c + 4c35b57 commit 0b587f3
Show file tree
Hide file tree
Showing 17 changed files with 282 additions and 99 deletions.
31 changes: 20 additions & 11 deletions src/main/kotlin/com/github/avrokotlin/avro4k/SerialDescriptor.kt
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
package com.github.avrokotlin.avro4k

import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.descriptors.PolymorphicKind
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.SerialKind
import kotlinx.serialization.descriptors.elementDescriptors
import kotlinx.serialization.descriptors.*
import kotlinx.serialization.modules.SerializersModule

@ExperimentalSerializationApi
fun SerialDescriptor.leavesOfSealedClasses() : List<SerialDescriptor> {
return if (this.kind == PolymorphicKind.SEALED) {
elementDescriptors.filter {it.kind == SerialKind.CONTEXTUAL }.flatMap { it.elementDescriptors }.flatMap { it.leavesOfSealedClasses() }
} else {
listOf(this)
}
}
fun SerialDescriptor.leavesOfSealedClasses(): List<SerialDescriptor> {
return if (this.kind == PolymorphicKind.SEALED) {
elementDescriptors.filter { it.kind == SerialKind.CONTEXTUAL }.flatMap { it.elementDescriptors }
.flatMap { it.leavesOfSealedClasses() }
} else {
listOf(this)
}
}

@ExperimentalSerializationApi
fun SerialDescriptor.possibleSerializationSubclasses(serializersModule: SerializersModule): List<SerialDescriptor> {
return when (this.kind) {
PolymorphicKind.SEALED -> this.leavesOfSealedClasses()
PolymorphicKind.OPEN -> serializersModule.getPolymorphicDescriptors(this).sortedBy { it.serialName }
else -> throw UnsupportedOperationException("Can't get possible serialization subclasses for the SerialDescriptor of kind ${this.kind}.")
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class ListDecoder(
StructureKind.CLASS -> RecordDecoder(descriptor, array[index] as GenericRecord, serializersModule, configuration)
StructureKind.LIST -> ListDecoder(schema.elementType, array[index] as GenericArray<*>, serializersModule, configuration)
StructureKind.MAP -> MapDecoder(descriptor, schema.elementType, array[index] as Map<String, *>, serializersModule, configuration)
PolymorphicKind.SEALED -> SealedClassDecoder(descriptor,array[index] as GenericRecord, serializersModule, configuration)
PolymorphicKind.SEALED -> UnionDecoder(descriptor,array[index] as GenericRecord, serializersModule, configuration)
else -> throw UnsupportedOperationException("Kind ${descriptor.kind} is currently not supported.")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class RecordDecoder(
}
decoder
}
PolymorphicKind.SEALED -> SealedClassDecoder(descriptor,value as GenericRecord, serializersModule, configuration)
PolymorphicKind.SEALED, PolymorphicKind.OPEN -> UnionDecoder(descriptor,value as GenericRecord, serializersModule, configuration)
else -> throw UnsupportedOperationException("Decoding descriptor of kind ${descriptor.kind} is currently not supported")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class RootRecordDecoder(
serializersModule,
configuration
)
PolymorphicKind.SEALED -> SealedClassDecoder(descriptor, record, serializersModule, configuration)
PolymorphicKind.SEALED -> UnionDecoder(descriptor, record, serializersModule, configuration)
else -> throw SerializationException("Non-class structure passed to root record decoder")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package com.github.avrokotlin.avro4k.decoder

import com.github.avrokotlin.avro4k.AvroConfiguration
import com.github.avrokotlin.avro4k.RecordNaming
import com.github.avrokotlin.avro4k.leavesOfSealedClasses
import com.github.avrokotlin.avro4k.possibleSerializationSubclasses
import kotlinx.serialization.DeserializationStrategy
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.SerializationException
Expand All @@ -14,10 +14,10 @@ import org.apache.avro.Schema
import org.apache.avro.generic.GenericRecord

@ExperimentalSerializationApi
class SealedClassDecoder (descriptor: SerialDescriptor,
private val value: GenericRecord,
override val serializersModule: SerializersModule,
private val configuration: AvroConfiguration
class UnionDecoder (descriptor: SerialDescriptor,
private val value: GenericRecord,
override val serializersModule: SerializersModule,
private val configuration: AvroConfiguration
) : AbstractDecoder(), FieldDecoder
{
private enum class DecoderState(val index : Int){
Expand All @@ -28,7 +28,7 @@ class SealedClassDecoder (descriptor: SerialDescriptor,
}
private var currentState = DecoderState.BEFORE

var leafDescriptor : SerialDescriptor = descriptor.leavesOfSealedClasses().firstOrNull {
private var leafDescriptor : SerialDescriptor = descriptor.possibleSerializationSubclasses(serializersModule).firstOrNull {
val schemaName = RecordNaming(value.schema.fullName, emptyList())
val serialName = RecordNaming(it)
serialName == schemaName
Expand All @@ -54,5 +54,5 @@ class SealedClassDecoder (descriptor: SerialDescriptor,
return recordDecoder.decodeSerializableValue(deserializer)
}

override fun decodeAny(): Any? = UnsupportedOperationException()
override fun decodeAny(): Any = UnsupportedOperationException()
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ interface StructureEncoder : FieldEncoder {
}
StructureKind.CLASS -> RecordEncoder(fieldSchema(), serializersModule) { addValue(it) }
StructureKind.MAP -> MapEncoder(fieldSchema(), serializersModule) { addValue(it) }
PolymorphicKind.SEALED -> SealedClassEncoder(fieldSchema(), serializersModule) { addValue(it) }
is PolymorphicKind -> UnionEncoder(fieldSchema(), serializersModule) { addValue(it) }
else -> throw SerializationException(".beginStructure was called on a non-structure type [$descriptor]")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class RootRecordEncoder(private val schema: Schema,
override fun beginStructure(descriptor: SerialDescriptor): CompositeEncoder {
return when (descriptor.kind) {
is StructureKind.CLASS -> RecordEncoder(schema, serializersModule, callback)
is PolymorphicKind.SEALED -> SealedClassEncoder(schema,serializersModule,callback)
is PolymorphicKind -> UnionEncoder(schema,serializersModule,callback)
else -> throw SerializationException("Unsupported root element passed to root record encoder")
}
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ import org.apache.avro.Schema
class UnionEncoder(private val unionSchema : Schema,
override val serializersModule: SerializersModule,
private val callback: (Record) -> Unit) : AbstractEncoder() {

override fun encodeString(value: String){
//No need to encode the name of the concrete type. The name will never be encoded in the avro schema.
}
override fun beginStructure(descriptor: SerialDescriptor): CompositeEncoder {
return when (descriptor.kind) {
is StructureKind.CLASS, is StructureKind.OBJECT -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,13 @@ fun schemaFor(serializersModule: SerializersModule,
namingStrategy,
resolvedSchemas
)
PolymorphicKind.SEALED -> SealedClassSchemaFor(descriptor, namingStrategy, serializersModule, resolvedSchemas)
StructureKind.CLASS, StructureKind.OBJECT -> when (descriptor.serialName) {
"kotlin.Pair" -> PairSchemaFor(descriptor, namingStrategy, serializersModule, resolvedSchemas)
else -> ClassSchemaFor(descriptor, namingStrategy, serializersModule, resolvedSchemas)
}
StructureKind.LIST -> ListSchemaFor(descriptor, serializersModule, namingStrategy, resolvedSchemas)
StructureKind.MAP -> MapSchemaFor(descriptor, serializersModule, namingStrategy, resolvedSchemas)
is PolymorphicKind -> UnionSchemaFor(descriptor, namingStrategy, serializersModule, resolvedSchemas)
else -> throw SerializationException("Unsupported type ${descriptor.serialName} of ${descriptor.kind}")
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
package com.github.avrokotlin.avro4k.schema

import com.github.avrokotlin.avro4k.RecordNaming
import com.github.avrokotlin.avro4k.leavesOfSealedClasses
import com.github.avrokotlin.avro4k.possibleSerializationSubclasses
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.modules.SerializersModule
import org.apache.avro.Schema

@ExperimentalSerializationApi
class SealedClassSchemaFor(private val descriptor: SerialDescriptor,
private val namingStrategy: NamingStrategy,
private val serializersModule: SerializersModule,
private val resolvedSchemas: MutableMap<RecordNaming, Schema>
class UnionSchemaFor(private val descriptor: SerialDescriptor,
private val namingStrategy: NamingStrategy,
private val serializersModule: SerializersModule,
private val resolvedSchemas: MutableMap<RecordNaming, Schema>
) : SchemaFor {
override fun schema(): Schema {
val leafSerialDescriptors = descriptor.leavesOfSealedClasses()
val leafSerialDescriptors = descriptor.possibleSerializationSubclasses(serializersModule)
return Schema.createUnion(
leafSerialDescriptors.map { ClassSchemaFor(it,namingStrategy,serializersModule, resolvedSchemas).schema() }
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package com.github.avrokotlin.avro4k.io

import com.github.avrokotlin.avro4k.Avro
import com.github.avrokotlin.avro4k.schema.ReferencingPolymorphicRoot
import com.github.avrokotlin.avro4k.schema.UnsealedChildOne
import com.github.avrokotlin.avro4k.schema.UnsealedChildTwo
import com.github.avrokotlin.avro4k.schema.UnsealedPolymorphicRoot
import io.kotest.core.spec.style.StringSpec
import io.kotest.matchers.shouldBe
import kotlinx.serialization.modules.SerializersModule
import kotlinx.serialization.modules.polymorphic
import kotlinx.serialization.modules.subclass
import org.apache.avro.generic.GenericRecord

class PolymorphicClassIoTest : StringSpec({
"read / write nested polymorphic class" {
val module = SerializersModule {
polymorphic(UnsealedPolymorphicRoot::class) {
subclass(UnsealedChildOne::class)
subclass(UnsealedChildTwo::class)
}
}
val avro = Avro(serializersModule = module)
writeRead(ReferencingPolymorphicRoot(UnsealedChildOne("one")), ReferencingPolymorphicRoot.serializer(), avro)
writeRead(ReferencingPolymorphicRoot(UnsealedChildOne("one")), ReferencingPolymorphicRoot.serializer(), avro) {
val root = it["root"] as GenericRecord
root.schema shouldBe avro.schema(UnsealedChildOne.serializer())
}
}
})
90 changes: 45 additions & 45 deletions src/test/kotlin/com/github/avrokotlin/avro4k/io/StreamTests.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,75 +11,75 @@ import org.apache.avro.generic.GenericRecord
import org.apache.avro.io.DecoderFactory
import java.io.ByteArrayOutputStream

fun <T> writeRead(t: T, serializer: KSerializer<T>) {
writeData(t, serializer).apply {
val record = readData(this, serializer)
val tt = Avro.default.fromRecord(serializer, record)
fun <T> writeRead(t: T, serializer: KSerializer<T>, avro: Avro = Avro.default) {
writeData(t, serializer, avro).apply {
val record = readData(this, serializer, avro)
val tt = avro.fromRecord(serializer, record)
t shouldBe tt
}
writeBinary(t, serializer).apply {
val record = readBinary(this, serializer)
val tt = Avro.default.fromRecord(serializer, record)
writeBinary(t, serializer, avro).apply {
val record = readBinary(this, serializer, avro)
val tt = avro.fromRecord(serializer, record)
t shouldBe tt
}
writeJson(t, serializer).apply {
val record = readJson(this, serializer)
val tt = Avro.default.fromRecord(serializer, record)
writeJson(t, serializer, avro).apply {
val record = readJson(this, serializer, avro)
val tt = avro.fromRecord(serializer, record)
t shouldBe tt
}
}

fun <T> writeRead(t: T, expected: T, serializer: KSerializer<T>) {
writeData(t, serializer).apply {
val record = readData(this, serializer)
val tt = Avro.default.fromRecord(serializer, record)
fun <T> writeRead(t: T, expected: T, serializer: KSerializer<T>, avro: Avro = Avro.default) {
writeData(t, serializer, avro).apply {
val record = readData(this, serializer, avro)
val tt = avro.fromRecord(serializer, record)
tt shouldBe expected
}
writeBinary(t, serializer).apply {
val record = readBinary(this, serializer)
val tt = Avro.default.fromRecord(serializer, record)
writeBinary(t, serializer, avro).apply {
val record = readBinary(this, serializer, avro)
val tt = avro.fromRecord(serializer, record)
tt shouldBe expected
}
}

fun <T> writeRead(t: T, serializer: KSerializer<T>, test: (GenericRecord) -> Any) {
writeData(t, serializer).apply {
val record = readData(this, serializer)
fun <T> writeRead(t: T, serializer: KSerializer<T>, avro: Avro = Avro.default, test: (GenericRecord) -> Any) {
writeData(t, serializer, avro).apply {
val record = readData(this, serializer, avro)
test(record)
}
writeBinary(t, serializer).apply {
val record = readBinary(this, serializer)
writeBinary(t, serializer, avro).apply {
val record = readBinary(this, serializer, avro)
test(record)
}
writeJson(t, serializer).apply {
val record = readJson(this, serializer)
writeJson(t, serializer, avro).apply {
val record = readJson(this, serializer, avro)
test(record)
}
}

fun <T> writeData(t: T, serializer: SerializationStrategy<T>): ByteArray {
val schema = Avro.default.schema(serializer)
fun <T> writeData(t: T, serializer: SerializationStrategy<T>, avro: Avro = Avro.default): ByteArray {
val schema = avro.schema(serializer)
val out = ByteArrayOutputStream()
val avro = Avro.default.openOutputStream(serializer) {
val output = avro.openOutputStream(serializer) {
encodeFormat = AvroEncodeFormat.Data()
this.schema = schema
}.to(out)
avro.write(t)
avro.close()
output.write(t)
output.close()
return out.toByteArray()
}

fun <T> readJson(bytes: ByteArray, serializer: KSerializer<T>): GenericRecord {
val schema = Avro.default.schema(serializer)
fun <T> readJson(bytes: ByteArray, serializer: KSerializer<T>, avro: Avro = Avro.default): GenericRecord {
val schema = avro.schema(serializer)
val datumReader = GenericDatumReader<GenericRecord>(schema)
val decoder = DecoderFactory.get().jsonDecoder(schema, SeekableByteArrayInput(bytes))
return datumReader.read(null, decoder)
}

fun <T> writeJson(t: T, serializer: KSerializer<T>): ByteArray {
val schema = Avro.default.schema(serializer)
fun <T> writeJson(t: T, serializer: KSerializer<T>, avro: Avro = Avro.default): ByteArray {
val schema = avro.schema(serializer)
val baos = ByteArrayOutputStream()
val output = Avro.default.openOutputStream(serializer) {
val output = avro.openOutputStream(serializer) {
encodeFormat = AvroEncodeFormat.Json
this.schema = schema
}.to(baos)
Expand All @@ -88,28 +88,28 @@ fun <T> writeJson(t: T, serializer: KSerializer<T>): ByteArray {
return baos.toByteArray()
}

fun <T> readData(bytes: ByteArray, serializer: KSerializer<T>): GenericRecord {
val schema = Avro.default.schema(serializer)
val avro = Avro.default.openInputStream {
fun <T> readData(bytes: ByteArray, serializer: KSerializer<T>, avro: Avro = Avro.default): GenericRecord {
val schema = avro.schema(serializer)
val input = avro.openInputStream {
decodeFormat = AvroDecodeFormat.Data(schema)
}.from(bytes)
return avro.next() as GenericRecord
return input.next() as GenericRecord
}

fun <T> writeBinary(t: T, serializer: SerializationStrategy<T>): ByteArray {
val schema = Avro.default.schema(serializer)
fun <T> writeBinary(t: T, serializer: SerializationStrategy<T>, avro: Avro = Avro.default): ByteArray {
val schema = avro.schema(serializer)
val out = ByteArrayOutputStream()
val avro = Avro.default.openOutputStream(serializer) {
val output = avro.openOutputStream(serializer) {
encodeFormat = AvroEncodeFormat.Binary
this.schema = schema
}.to(out)
avro.write(t)
avro.close()
output.write(t)
output.close()
return out.toByteArray()
}

fun <T> readBinary(bytes: ByteArray, serializer: KSerializer<T>): GenericRecord {
val schema = Avro.default.schema(serializer)
fun <T> readBinary(bytes: ByteArray, serializer: KSerializer<T>, avro: Avro = Avro.default): GenericRecord {
val schema = avro.schema(serializer)
val datumReader = GenericDatumReader<GenericRecord>(schema)
val decoder = DecoderFactory.get().binaryDecoder(SeekableByteArrayInput(bytes), null)
return datumReader.read(null, decoder)
Expand Down

0 comments on commit 0b587f3

Please sign in to comment.