diff --git a/CHANGELOG.md b/CHANGELOG.md index 898a87d5b..a9cbed6b4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +# 0.18.11 + +* Added a utility method, `Schema.transformTransitivelyK`, to help in recursively transforming schemas. +In addition, the semantics of `transformHintsTransitively` have been changed: the transformation no longer modifies the hints on the result of the `total` function. + # 0.18.10 * Bumps alloy to 0.3.1. This is required as otherwise the `alloy#nullable` hints get filtered out when using SimpleRestJsonBuilder. diff --git a/modules/bootstrapped/test/src/smithy4s/schema/HintsTransformationSpec.scala b/modules/bootstrapped/test/src/smithy4s/schema/HintsTransformationSpec.scala index 3f794c312..f1ea8e3c5 100644 --- a/modules/bootstrapped/test/src/smithy4s/schema/HintsTransformationSpec.scala +++ b/modules/bootstrapped/test/src/smithy4s/schema/HintsTransformationSpec.scala @@ -214,7 +214,12 @@ class HintsTransformationSpec() extends FunSuite { values: List[EnumValue[E]], total: E => EnumValue[E] ): Count[E] = { e => - count(hints) + count(total(e).hints) + count(hints) + count( + values + .find(_.value == total(e).value) + .getOrElse(sys.error("Unknown enum value")) + .hints + ) } def struct[S]( diff --git a/modules/core/src/smithy4s/schema/Schema.scala b/modules/core/src/smithy4s/schema/Schema.scala index 7d8dc3ce0..142cd8301 100644 --- a/modules/core/src/smithy4s/schema/Schema.scala +++ b/modules/core/src/smithy4s/schema/Schema.scala @@ -66,18 +66,20 @@ sealed trait Schema[A]{ case s: OptionSchema[a] => OptionSchema(s.underlying.transformHintsLocally(f)).asInstanceOf[Schema[A]] } - final def transformHintsTransitively(f: Hints => Hints): Schema[A] = this match { - case PrimitiveSchema(shapeId, hints, tag) => PrimitiveSchema(shapeId, f(hints), tag) - case s: CollectionSchema[c, a] => CollectionSchema[c, a](s.shapeId, f(s.hints), s.tag, s.member.transformHintsTransitively(f)).asInstanceOf[Schema[A]] - case s: MapSchema[k, v] => MapSchema(s.shapeId, f(s.hints), s.key.transformHintsTransitively(f), s.value.transformHintsTransitively(f)).asInstanceOf[Schema[A]] - case EnumerationSchema(shapeId, hints, tag, values, total) => EnumerationSchema(shapeId, f(hints), tag, values.map(_.transformHints(f)), total andThen (_.transformHints(f))) - case StructSchema(shapeId, hints, fields, make) => StructSchema(shapeId, f(hints), fields.map(_.transformHintsTransitively(f)), make) - case UnionSchema(shapeId, hints, alternatives, dispatch) => UnionSchema(shapeId, f(hints), alternatives.map(_.transformHintsTransitively(f)), dispatch) - case BijectionSchema(schema, bijection) => BijectionSchema(schema.transformHintsTransitively(f), bijection) - case RefinementSchema(schema, refinement) => RefinementSchema(schema.transformHintsTransitively(f), refinement) - case LazySchema(suspend) => LazySchema(suspend.map(_.transformHintsTransitively(f))) - case s: OptionSchema[a] => OptionSchema(s.underlying.transformHintsTransitively(f)).asInstanceOf[Schema[A]] - } + final def transformHintsTransitively(f: Hints => Hints): Schema[A] = transformTransitivelyK(new (Schema ~> Schema) { + def apply[B](fa: Schema[B]): Schema[B] = { + val base = fa.transformHintsLocally(f) + + base match { + case EnumerationSchema(shapeId, hints, tag, values, total) => + EnumerationSchema(shapeId, hints, tag, values.map(_.transformHints(f)), total) + + case other => other + } + } + }) + + def transformTransitivelyK(f: Schema ~> Schema): Schema[A] = compile(new TransitiveCompiler(f)) final def validated[C](c: C)(implicit constraint: RefinementProvider.Simple[C, A]): Schema[A] = { val hint = Hints.Binding.fromValue(c)(constraint.tag) @@ -190,6 +192,53 @@ object Schema { def apply[A](fa: Schema[A]): Schema[A] = fa.transformHintsTransitively(f) } + /** + * Transforms this schema, and all the schemas inside it, using the provided function. + */ + def transformTransitivelyK(f: Schema ~> Schema): Schema ~> Schema = new (Schema ~> Schema) { + def apply[A](fa: Schema[A]): Schema[A] = fa.transformTransitivelyK(f) + } + + // format: on + private final class TransitiveCompiler( + underlying: Schema ~> Schema + ) extends (Schema ~> Schema) { + + def apply[A]( + fa: Schema[A] + ): Schema[A] = fa match { + case e @ EnumerationSchema(_, _, _, _, _) => underlying(e) + case p @ PrimitiveSchema(_, _, _) => underlying(p) + case u @ UnionSchema(_, _, _, _) => + underlying(u.copy(alternatives = u.alternatives.map(handleAlt(_)))) + case BijectionSchema(s, bijection) => + underlying(BijectionSchema(this(s), bijection)) + case LazySchema(suspend) => + underlying(LazySchema(suspend.map(this.apply))) + case RefinementSchema(s, refinement) => + underlying(RefinementSchema(this(s), refinement)) + case c: CollectionSchema[c, a] => + underlying(c.copy(member = this(c.member))) + case m @ MapSchema(_, _, _, _) => + underlying(m.copy(key = this(m.key), value = this(m.value))) + case s @ StructSchema(_, _, _, _) => + underlying(s.copy(fields = s.fields.map(handleField(_)))) + case n @ OptionSchema(_) => + underlying(n.copy(underlying = this(n.underlying))) + } + + private def handleField[S, A]( + field: Field[S, A] + ): Field[S, A] = field.copy(schema = this(field.schema)) + + private def handleAlt[S, A]( + alt: Alt[S, A] + ): Alt[S, A] = alt.copy(schema = this(alt.schema)) + } + + // format: off + + ////////////////////////////////////////////////////////////////////////////////////////////////// // SCHEMA BUILDER //////////////////////////////////////////////////////////////////////////////////////////////////