Skip to content

Commit

Permalink
Added explicit sealed subtypes (#380)
Browse files Browse the repository at this point in the history
  • Loading branch information
sksamuel committed May 27, 2023
1 parent 3a8259f commit 2e60d25
Show file tree
Hide file tree
Showing 9 changed files with 382 additions and 42 deletions.
17 changes: 13 additions & 4 deletions hoplite-core/src/main/kotlin/com/sksamuel/hoplite/ConfigFailure.kt
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ sealed interface ConfigFailure {
override fun description(): String = "Data class ${kclass.qualifiedName} has no constructors"
}

data class NoSuchSealedSubtype(val kclass: KClass<*>, val value: String) : ConfigFailure {
override fun description(): String =
"No sealed subtype of `${kclass.jvmName}` was found using the discriminator value `$value`"
}

data class UnknownSource(val source: String) : ConfigFailure {
override fun description(): String = "Could not find $source"
}
Expand All @@ -114,10 +119,14 @@ sealed interface ConfigFailure {
override fun description(): String = message
}

data class NoSealedClassObjectSubtype(val type: KClass<*>, val node: StringNode) : ConfigFailure {
data class InvalidDiscriminatorField(val kclass: KClass<*>, val field: String) : ConfigFailure {
override fun description(): String = "Invalid discriminator field to select sealed subtype. Must specify `$field` to be a valid subtype of `${kclass.java.name}`."
}

data class NoSealedClassObjectSubtype(val kclass: KClass<*>, val subtypeName: String) : ConfigFailure {
override fun description(): String {
val subclasses = type.sealedSubclasses.joinToString(", ") { it.jvmName }
return "Could not find subclass of $type matching name ${node.value}: Tried $subclasses ${node.pos.loc()}"
val subclasses = kclass.sealedSubclasses.joinToString(", ") { it.java.simpleName }
return "Could not find subclass of $kclass matching name ${subtypeName}: Available $subclasses"
}
}

Expand All @@ -138,7 +147,7 @@ sealed interface ConfigFailure {
}

data class SealedClassWithoutImpls(val type: KClass<*>) : ConfigFailure {
override fun description(): String = "Sealed class $type does not define any subclasses"
override fun description(): String = "Sealed class `${type.jvmName}` does not define any subclasses"
}

data class SealedClassWithoutObject(val type: KClass<*>) : ConfigFailure {
Expand Down
42 changes: 28 additions & 14 deletions hoplite-core/src/main/kotlin/com/sksamuel/hoplite/ConfigLoader.kt
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,23 @@ class ConfigLoader(
val secretsPolicy: SecretsPolicy? = null,
val environment: Environment? = null,
val obfuscator: Obfuscator? = null,
val reportPrintFn: Print? = null,
val reportPrintFn: Print = { println(it) },
val flattenArraysToString: Boolean = false,
val resolvers: List<Resolver> = emptyList()
val resolvers: List<Resolver> = emptyList(),
val sealedTypeDiscriminatorField: String? = null,
) {

init {
if (sealedTypeDiscriminatorField == null) {
reportPrintFn.invoke(
"Hoplite is configured to infer which sealed type to choose by inspecting the config values at runtime. " +
"This behaviour is now deprecated in favour of explicitly specifying the type through a discriminator field. " +
"In 3.0 this new behavior will become the default. " +
"To enable this behavior now (and disable this warning), invoke withExplicitSealedTypes() on the ConfigLoaderBuilder."
)
}
}

companion object {

/**
Expand Down Expand Up @@ -90,7 +102,7 @@ class ConfigLoader(
*/
inline fun <reified A : Any> loadConfigOrThrow(
resourceOrFiles: List<String>,
classpathResourceLoader: ClasspathResourceLoader = ConfigSource.Companion::class.java.toClasspathResourceLoader()
classpathResourceLoader: ClasspathResourceLoader = ConfigSource.Companion::class.java.toClasspathResourceLoader(),
): A = loadConfig<A>(resourceOrFiles, classpathResourceLoader).returnOrThrow()

/**
Expand All @@ -111,7 +123,7 @@ class ConfigLoader(
*/
inline fun <reified A : Any> loadConfig(
vararg resourceOrFiles: String,
classpathResourceLoader: ClasspathResourceLoader = ConfigSource.Companion::class.java.toClasspathResourceLoader()
classpathResourceLoader: ClasspathResourceLoader = ConfigSource.Companion::class.java.toClasspathResourceLoader(),
): ConfigResult<A> = loadConfig(resourceOrFiles.toList(), classpathResourceLoader)

/**
Expand All @@ -124,7 +136,7 @@ class ConfigLoader(
*/
inline fun <reified A : Any> loadConfig(
resourceOrFiles: List<String>,
classpathResourceLoader: ClasspathResourceLoader = Companion::class.java.toClasspathResourceLoader()
classpathResourceLoader: ClasspathResourceLoader = Companion::class.java.toClasspathResourceLoader(),
): ConfigResult<A> = loadConfig(A::class, emptyList(), resourceOrFiles, classpathResourceLoader)

/**
Expand Down Expand Up @@ -159,9 +171,10 @@ class ConfigLoader(
decodeMode = decodeMode,
useReport = useReport,
obfuscator = obfuscator ?: PrefixObfuscator(3),
reportPrintFn = reportPrintFn ?: { println(it) },
reportPrintFn = reportPrintFn,
environment = environment,
resolvers = resolvers
resolvers = resolvers,
sealedTypeDiscriminatorField = sealedTypeDiscriminatorField,
).decode(kclass, environment, resourceOrFiles, propertySources, configSources)
}

Expand Down Expand Up @@ -203,16 +216,17 @@ class ConfigLoader(
preprocessors = preprocessors,
preprocessingIterations = preprocessingIterations,
decoderRegistry = decoderRegistry,
paramMappers = paramMappers,
flattenArraysToString = false, // not needed to load nodes
allowUnresolvedSubstitutions = allowUnresolvedSubstitutions,
secretsPolicy = null, // not used when loading nodes
paramMappers = paramMappers, // not needed to load nodes
flattenArraysToString = false,
allowUnresolvedSubstitutions = allowUnresolvedSubstitutions, // not used when loading nodes
secretsPolicy = null, // not used when loading nodes
decodeMode = DecodeMode.Lenient, // not used when loading nodes
useReport = false, // not used when loading nodes
obfuscator = StrictObfuscator("*"), // not used when loading nodes
reportPrintFn = reportPrintFn ?: { }, // not used when loading nodes
obfuscator = StrictObfuscator("*"), // not used when loading nodes
reportPrintFn = reportPrintFn ?: { },
environment = environment,
resolvers = resolvers
resolvers = resolvers,
sealedTypeDiscriminatorField = sealedTypeDiscriminatorField,
).load(resourceOrFiles, propertySources, configSources)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class ConfigLoaderBuilder private constructor() {
private var cascadeMode: CascadeMode = CascadeMode.Merge
private var allowEmptySources = false
private var allowUnresolvedSubstitutions = false
private var sealedTypeDiscriminatorField: String? = null
private var contextResolverMode = ContextResolverMode.Error

private val propertySources = mutableListOf<PropertySource>()
Expand Down Expand Up @@ -200,6 +201,7 @@ class ConfigLoaderBuilder private constructor() {
}
}

@Deprecated("Replaced with resolvers")
fun withPreprocessingIterations(iterations: Int): ConfigLoaderBuilder = apply {
preprocessingIterations = iterations
}
Expand Down Expand Up @@ -260,7 +262,7 @@ class ConfigLoaderBuilder private constructor() {
allowUnresolvedSubstitutions = true
}

fun withSubstitutionMode(mode: ContextResolverMode) = apply {
fun withContextResolverMode(mode: ContextResolverMode) = apply {
contextResolverMode = mode
}

Expand Down Expand Up @@ -308,6 +310,18 @@ class ConfigLoaderBuilder private constructor() {
)
fun report(reporter: Reporter) = apply { useReport = true }

/**
* Set a field name to be used as the discriminator field for sealed types.
*
* Then, Hoplite will use this field to pick amongst the sealed types instead of trying to
* infer the type from the available config values.
*
* This option will become the default in 3.0.
*/
@ExperimentalHoplite
fun withExplicitSealedTypes(discriminatorFieldName: String = "_type"): ConfigLoaderBuilder =
apply { sealedTypeDiscriminatorField = discriminatorFieldName }

fun build(): ConfigLoader {
return ConfigLoader(
decoderRegistry = DefaultDecoderRegistry(decoders),
Expand All @@ -327,7 +341,8 @@ class ConfigLoaderBuilder private constructor() {
environment = environment,
obfuscator = obfuscator,
reportPrintFn = reportPrintFn,
flattenArraysToString = flattenArraysToString
flattenArraysToString = flattenArraysToString,
sealedTypeDiscriminatorField = sealedTypeDiscriminatorField,
)
}
}
Expand All @@ -336,27 +351,27 @@ fun defaultPropertySources(): List<PropertySource> = listOfNotNull(
EnvironmentVariableOverridePropertySource(true),
SystemPropertiesPropertySource,
UserSettingsPropertySource,
XdgConfigPropertySource
XdgConfigPropertySource,
)

fun defaultPreprocessors(): List<Preprocessor> = listOf(
EnvOrSystemPropertyPreprocessor,
RandomPreprocessor,
LookupPreprocessor
LookupPreprocessor,
)

fun defaultResolvers(): List<Resolver> = listOf(
EnvVarContextResolver,
SystemPropertyContextResolver,
ReferenceContextResolver,
HopliteContextResolver
HopliteContextResolver,
)

fun defaultParamMappers(): List<ParameterMapper> = listOf(
DefaultParamMapper,
SnakeCaseParamMapper,
KebabCaseParamMapper,
AliasAnnotationParamMapper
AliasAnnotationParamMapper,
)

val defaultDecoders = listOf(
Expand Down Expand Up @@ -409,5 +424,5 @@ val defaultDecoders = listOf(
com.sksamuel.hoplite.decoder.SecondsDecoder(),
com.sksamuel.hoplite.decoder.InlineClassDecoder(),
com.sksamuel.hoplite.decoder.SealedClassDecoder(),
com.sksamuel.hoplite.decoder.DataClassDecoder()
com.sksamuel.hoplite.decoder.DataClassDecoder(),
)
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@ data class DecoderContext(
val environment: Environment? = null,
val resolvers: Resolving = Resolving.empty,
// determines if we should error when a context resolver cannot find a substitution
val contextResolverMode: ContextResolverMode = ContextResolverMode.Error
val contextResolverMode: ContextResolverMode = ContextResolverMode.Error,
val sealedTypeDiscriminatorField: String? = null,
) {


/**
* Returns a [Decoder] for type [type].
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import com.sksamuel.hoplite.DecoderContext
import com.sksamuel.hoplite.MapNode
import com.sksamuel.hoplite.Node
import com.sksamuel.hoplite.StringNode
import com.sksamuel.hoplite.fp.Validated
import com.sksamuel.hoplite.fp.invalid
import com.sksamuel.hoplite.fp.plus
import com.sksamuel.hoplite.fp.sequence
import com.sksamuel.hoplite.fp.valid
import com.sksamuel.hoplite.valueOrNull
import kotlin.reflect.KClass
import kotlin.reflect.KFunction
import kotlin.reflect.KType
Expand All @@ -27,21 +29,64 @@ class SealedClassDecoder : NullHandlingDecoder<Any> {
}
}

// it's common to have custom decoders for sealed classes, so sealed classes should be very low priority
// it's common to have custom decoders for sealed classes, so this decoder should be very low priority
override fun priority(): Int = Integer.MIN_VALUE + 100

override fun safeDecode(
node: Node,
type: KType,
context: DecoderContext
): ConfigResult<Any> {
// to determine which sealed class to use, we can just try each in turn until one results in success

val kclass = type.classifier as KClass<*>

// if we have no subclasses then that is an error of course
if (kclass.sealedSubclasses.isEmpty()) return ConfigFailure.SealedClassWithoutImpls(kclass).invalid()

return when (val field = context.sealedTypeDiscriminatorField) {
null -> deriveInstance(node, type, context)
else -> useDiscriminator(field, node, type, context)
}
}

private fun useDiscriminator(field: String, node: Node, type: KType, context: DecoderContext): ConfigResult<Any> {
val kclass = type.classifier as KClass<*>
val subclasses = kclass.sealedSubclasses

// when explicitly specifying subtypes, we must have a map type containing the disriminator field,
// or a string type referencing an object instance
return when (node) {
is StringNode -> {
val referencedName = node.value
val subtype = subclasses.find { it.java.simpleName == referencedName }?.objectInstance
subtype?.valid() ?: ConfigFailure.NoSealedClassObjectSubtype(kclass, referencedName).invalid()
}
is MapNode -> {
when (val discriminatorField = node[field]) {
is StringNode -> {
val subtype = subclasses.find { it.java.simpleName == discriminatorField.value }
if (subtype == null) {
ConfigFailure.NoSuchSealedSubtype(kclass, discriminatorField.value).invalid()
} else {
// check for object-ness first
subtype.objectInstance?.valid()
// now we know the type is not an object, we can use the data class decoder directly
?: DataClassDecoder().decode(node, subtype.createType(), context)
}
}
else -> ConfigFailure.InvalidDiscriminatorField(kclass, field).invalid()
}
}
else -> ConfigFailure.Generic("Sealed type values must be maps or strings").invalid()
}
}

// to determine which sealed class to use, we can just try each in turn until one results in success
private fun deriveInstance(node: Node, type: KType, context: DecoderContext): Validated<ConfigFailure, Any> {
val kclass = type.classifier as KClass<*>
val subclasses = kclass.sealedSubclasses

return when {
// if we have no subclasses then that is an error of course
subclasses.isEmpty() -> ConfigFailure.SealedClassWithoutImpls(kclass).invalid()
// if we have a map with no values then we can look for an object subclass,
// but only if there is a single object subclass, otherwise we don't know which one to pick
node is MapNode && node.size == 0 -> {
Expand All @@ -58,7 +103,7 @@ class SealedClassDecoder : NullHandlingDecoder<Any> {
// we can use the object directly
val error = if (node is StringNode) {
val obj = subclasses.find { it.simpleName == node.value }?.objectInstance
if (obj != null) return obj.valid() else ConfigFailure.NoSealedClassObjectSubtype(kclass, node)
if (obj != null) return obj.valid() else ConfigFailure.NoSealedClassObjectSubtype(kclass, node.value)
} else null

val results = kclass.sealedSubclasses
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ class ConfigParser(
private val useReport: Boolean,
private val obfuscator: Obfuscator,
private val reportPrintFn: Print,
private val environment: Environment?
private val environment: Environment?,
private val sealedTypeDiscriminatorField: String?,
) {

private val loader = PropertySourceLoader(classpathResourceLoader, parserRegistry, allowEmptyTree)
Expand All @@ -55,7 +56,8 @@ class ConfigParser(
paramMappers = paramMappers,
config = DecoderConfig(flattenArraysToString),
environment = environment,
resolvers = Resolving(resolvers, root)
resolvers = Resolving(resolvers, root),
sealedTypeDiscriminatorField = sealedTypeDiscriminatorField,
)
}

Expand All @@ -64,7 +66,7 @@ class ConfigParser(
environment: Environment?,
resourceOrFiles: List<String>,
propertySources: List<PropertySource>,
configSources: List<ConfigSource>
configSources: List<ConfigSource>,
): ConfigResult<A> {

if (decoderRegistry.size == 0)
Expand All @@ -79,11 +81,11 @@ class ConfigParser(
val decoded = decoding.decode(kclass, preprocessed, decodeMode, context)
val state = createDecodingState(preprocessed, context, secretsPolicy)

// always do report regardless of decoder result
if (useReport) {
Reporter(reportPrintFn, obfuscator, environment)
.printReport(propertySources, state, context.reports)
}
// always do report regardless of decoder result
if (useReport) {
Reporter(reportPrintFn, obfuscator, environment)
.printReport(propertySources, state, context.reports)
}

decoded
}
Expand All @@ -95,7 +97,7 @@ class ConfigParser(
fun load(
resourceOrFiles: List<String>,
propertySources: List<PropertySource>,
configSources: List<ConfigSource>
configSources: List<ConfigSource>,
): ConfigResult<Node> {
return loader.loadNodes(propertySources, configSources, resourceOrFiles).flatMap { nodes ->
cascader.cascade(nodes).flatMap { node ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class EmptyDecoderRegistryTest : FunSpec() {
init {
test("empty decoder registry throws error") {
data class Config(val a: String)

val parsers = defaultParserRegistry()
val sources = defaultPropertySources()
val preprocessors = defaultPreprocessors()
Expand All @@ -21,7 +22,8 @@ class EmptyDecoderRegistryTest : FunSpec() {
preprocessors,
mappers,
allowEmptyTree = false,
allowUnresolvedSubstitutions = false
allowUnresolvedSubstitutions = false,
sealedTypeDiscriminatorField = null,
).loadConfig<Config>()
e as Validated.Invalid<ConfigFailure>
e.error shouldBe ConfigFailure.EmptyDecoderRegistry
Expand Down

0 comments on commit 2e60d25

Please sign in to comment.