Skip to content

Commit

Permalink
Merge pull request #1426 from disneystreaming/add-transitive-schema-t…
Browse files Browse the repository at this point in the history
…ransformation

Add transformTransitivelyK
  • Loading branch information
Baccata committed Mar 4, 2024
2 parents 26ba263 + 3960675 commit 8127d63
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 13 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,8 @@
# 0.18.11

* Added a utility method, `Schema.transformTransitivelyK`, to help in recursively transforming schemas.
In addition, the semantics of `transformHintsTransitively` have been changed: the transformation no longer modifies the hints on the result of the `total` function.

# 0.18.10

* Bumps alloy to 0.3.1. This is required as otherwise the `alloy#nullable` hints get filtered out when using SimpleRestJsonBuilder.
Expand Down
Expand Up @@ -214,7 +214,12 @@ class HintsTransformationSpec() extends FunSuite {
values: List[EnumValue[E]],
total: E => EnumValue[E]
): Count[E] = { e =>
count(hints) + count(total(e).hints)
count(hints) + count(
values
.find(_.value == total(e).value)
.getOrElse(sys.error("Unknown enum value"))
.hints
)
}

def struct[S](
Expand Down
73 changes: 61 additions & 12 deletions modules/core/src/smithy4s/schema/Schema.scala
Expand Up @@ -66,18 +66,20 @@ sealed trait Schema[A]{
case s: OptionSchema[a] => OptionSchema(s.underlying.transformHintsLocally(f)).asInstanceOf[Schema[A]]
}

final def transformHintsTransitively(f: Hints => Hints): Schema[A] = this match {
case PrimitiveSchema(shapeId, hints, tag) => PrimitiveSchema(shapeId, f(hints), tag)
case s: CollectionSchema[c, a] => CollectionSchema[c, a](s.shapeId, f(s.hints), s.tag, s.member.transformHintsTransitively(f)).asInstanceOf[Schema[A]]
case s: MapSchema[k, v] => MapSchema(s.shapeId, f(s.hints), s.key.transformHintsTransitively(f), s.value.transformHintsTransitively(f)).asInstanceOf[Schema[A]]
case EnumerationSchema(shapeId, hints, tag, values, total) => EnumerationSchema(shapeId, f(hints), tag, values.map(_.transformHints(f)), total andThen (_.transformHints(f)))
case StructSchema(shapeId, hints, fields, make) => StructSchema(shapeId, f(hints), fields.map(_.transformHintsTransitively(f)), make)
case UnionSchema(shapeId, hints, alternatives, dispatch) => UnionSchema(shapeId, f(hints), alternatives.map(_.transformHintsTransitively(f)), dispatch)
case BijectionSchema(schema, bijection) => BijectionSchema(schema.transformHintsTransitively(f), bijection)
case RefinementSchema(schema, refinement) => RefinementSchema(schema.transformHintsTransitively(f), refinement)
case LazySchema(suspend) => LazySchema(suspend.map(_.transformHintsTransitively(f)))
case s: OptionSchema[a] => OptionSchema(s.underlying.transformHintsTransitively(f)).asInstanceOf[Schema[A]]
}
final def transformHintsTransitively(f: Hints => Hints): Schema[A] = transformTransitivelyK(new (Schema ~> Schema) {
def apply[B](fa: Schema[B]): Schema[B] = {
val base = fa.transformHintsLocally(f)

base match {
case EnumerationSchema(shapeId, hints, tag, values, total) =>
EnumerationSchema(shapeId, hints, tag, values.map(_.transformHints(f)), total)

case other => other
}
}
})

def transformTransitivelyK(f: Schema ~> Schema): Schema[A] = compile(new TransitiveCompiler(f))

final def validated[C](c: C)(implicit constraint: RefinementProvider.Simple[C, A]): Schema[A] = {
val hint = Hints.Binding.fromValue(c)(constraint.tag)
Expand Down Expand Up @@ -190,6 +192,53 @@ object Schema {
def apply[A](fa: Schema[A]): Schema[A] = fa.transformHintsTransitively(f)
}

/**
* Transforms this schema, and all the schemas inside it, using the provided function.
*/
def transformTransitivelyK(f: Schema ~> Schema): Schema ~> Schema = new (Schema ~> Schema) {
def apply[A](fa: Schema[A]): Schema[A] = fa.transformTransitivelyK(f)
}

// format: on
private final class TransitiveCompiler(
underlying: Schema ~> Schema
) extends (Schema ~> Schema) {

def apply[A](
fa: Schema[A]
): Schema[A] = fa match {
case e @ EnumerationSchema(_, _, _, _, _) => underlying(e)
case p @ PrimitiveSchema(_, _, _) => underlying(p)
case u @ UnionSchema(_, _, _, _) =>
underlying(u.copy(alternatives = u.alternatives.map(handleAlt(_))))
case BijectionSchema(s, bijection) =>
underlying(BijectionSchema(this(s), bijection))
case LazySchema(suspend) =>
underlying(LazySchema(suspend.map(this.apply)))
case RefinementSchema(s, refinement) =>
underlying(RefinementSchema(this(s), refinement))
case c: CollectionSchema[c, a] =>
underlying(c.copy(member = this(c.member)))
case m @ MapSchema(_, _, _, _) =>
underlying(m.copy(key = this(m.key), value = this(m.value)))
case s @ StructSchema(_, _, _, _) =>
underlying(s.copy(fields = s.fields.map(handleField(_))))
case n @ OptionSchema(_) =>
underlying(n.copy(underlying = this(n.underlying)))
}

private def handleField[S, A](
field: Field[S, A]
): Field[S, A] = field.copy(schema = this(field.schema))

private def handleAlt[S, A](
alt: Alt[S, A]
): Alt[S, A] = alt.copy(schema = this(alt.schema))
}

// format: off


//////////////////////////////////////////////////////////////////////////////////////////////////
// SCHEMA BUILDER
//////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down

0 comments on commit 8127d63

Please sign in to comment.