/
KotlingradExpression.kt
59 lines (52 loc) · 2.14 KB
/
KotlingradExpression.kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
/*
* Copyright 2018-2024 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.kotlingrad
import ai.hypergraph.kotlingrad.api.SFun
import ai.hypergraph.kotlingrad.api.SVar
import space.kscience.attributes.SafeType
import space.kscience.kmath.expressions.*
import space.kscience.kmath.operations.NumericAlgebra
/**
 * A [DifferentiableExpression] backed by an [MST] whose derivatives are computed with
 * [Kotlin∇](https://github.com/breandan/kotlingrad).
 *
 * Evaluation interprets [mst] in [algebra]. Differentiation translates [mst] into an [SFun], applies the Kotlin∇
 * derivative once per requested symbol, and translates the resulting [SFun] back into an [MST].
 *
 * @param T The type of number.
 * @param A The [NumericAlgebra] of [T].
 * @property algebra The [A] instance used to interpret [mst].
 * @property mst The [MST] node wrapped by this expression.
 */
public class KotlingradExpression<T : Number, A : NumericAlgebra<T>>(
    public val algebra: A,
    public val mst: MST,
) : SpecialDifferentiableExpression<T, KotlingradExpression<T, A>> {
    /** The value type of this expression, taken from [algebra]. */
    override val type: SafeType<T> = algebra.type

    /** Evaluates [mst] in [algebra] with the given symbol bindings. */
    override fun invoke(arguments: Map<Symbol, T>): T = mst.interpret(algebra, arguments)

    /**
     * Differentiates [mst] with respect to each of [symbols], in list order, and wraps the result in a new
     * [KotlingradExpression] over the same [algebra].
     *
     * With an empty [symbols] list no derivative is applied; the tree only goes through the [SFun] round trip.
     */
    override fun derivativeOrNull(
        symbols: List<Symbol>,
    ): KotlingradExpression<T, A> {
        // Pass every symbol through MstNumericAlgebra.bindSymbol by its identity before handing it to Kotlin∇.
        val identities = symbols.map { it.identity }
        val boundSymbols = identities.map { MstNumericAlgebra.bindSymbol(it) }
        val variables = boundSymbols.map<Symbol, SVar<KMathNumber<T, A>>> { it.toSVar() }

        // Start from the Kotlin∇ form of the tree and differentiate once per variable.
        val seed: SFun<KMathNumber<T, A>> = mst.toSFun()
        val derivative = variables.fold(seed) { accumulated, variable -> accumulated.d(variable) }

        return KotlingradExpression(algebra, derivative.toMst())
    }
}
/**
 * An [AutoDiffProcessor] that builds an [MST] inside [MstExtendedField] and exposes it as a
 * [KotlingradExpression] over [algebra].
 *
 * @property algebra The [A] instance handed to every expression produced by this processor.
 */
public class KotlingradProcessor<T : Number, A : NumericAlgebra<T>>(
    public val algebra: A,
) : AutoDiffProcessor<T, MST, MstExtendedField> {
    /**
     * Runs [function] with [MstExtendedField] as its receiver and wraps the resulting [MST] into a
     * [KotlingradExpression] bound to [algebra].
     */
    override fun differentiate(function: MstExtendedField.() -> MST): DifferentiableExpression<T> {
        val tree: MST = function(MstExtendedField)
        return tree.toKotlingradExpression(algebra)
    }
}
/**
 * Creates a [KotlingradExpression] that holds this [MST] together with the given [algebra].
 *
 * @param algebra The algebra the returned expression is bound to.
 * @return A new [KotlingradExpression] whose `mst` is this node.
 */
public fun <T : Number, A : NumericAlgebra<T>> MST.toKotlingradExpression(algebra: A): KotlingradExpression<T, A> {
    return KotlingradExpression(algebra = algebra, mst = this)
}