-
Notifications
You must be signed in to change notification settings - Fork 17
/
lift.jl
86 lines (74 loc) · 2.22 KB
/
lift.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Lifting Functions to Functions
# No existential type (Exists{T}) in Julia yet, see https://github.com/JuliaLang/julia/issues/21026#issuecomment-306624369
"""
    liftnoesc(fnm, isrv)

Build an (unescaped) method-definition expression for `fnm`, meant to be passed
straight to `Core.eval`.

The generated method takes `length(isrv)` positional arguments named `x1, x2, …`;
argument `xk` is annotated `::Omega.RandVar` exactly when `isrv[k]` is `true`.
The body forwards every argument (annotations included, so they act as type
assertions in the call) to `Omega.ciid(fnm, ...)`.
"""
function liftnoesc(fnm::Union{Symbol, Expr}, isrv::NTuple{N, Bool}) where N
    params = Any[]
    for (k, israndvar) in enumerate(isrv)
        pname = Symbol(:x, k)
        push!(params, israndvar ? :($pname::Omega.RandVar) : pname)
    end
    return quote
        function $fnm($(params...))
            Omega.ciid($fnm, $(params...))
        end
    end
end
"""
    liftesc(fnm, isrv)

Build a macro-escaped method-definition expression for `fnm`, intended to be
returned from a macro.

The generated method takes `length(isrv)` positional arguments named `x1, x2, …`;
argument `xk` is annotated `::Omega.RandVar` exactly when `isrv[k]` is `true`.
The body forwards every argument to `Omega.ciid(fnm, ...)`.
"""
function liftesc(fnm::Union{Symbol, Expr}, isrv::NTuple{N, Bool}) where N
    args = [rv ? :($(Symbol(:x, i))::Omega.RandVar) : Symbol(:x, i)
            for (i, rv) in enumerate(isrv)]
    # Escape `fnm` in BOTH places: the method name and the function handed to
    # `ciid`. Otherwise macro hygiene resolves the body's reference in the
    # macro's defining module, not in the caller's scope where `fnm` lives.
    efnm = esc(fnm)
    quote
        function $efnm($(args...))
            Omega.ciid($efnm, $(args...))
        end
    end
end
"""
    lift(fnm, n; mod)

Evaluate into `mod` one lifted method of `fnm` for every `n`-argument pattern
returned by `rvcombinations(n)`; each method is generated by `liftnoesc`.
`mod` defaults to the module in which this method is defined.
"""
function lift(fnm::Union{Expr, Symbol}, n::Integer; mod::Module=@__MODULE__())
    foreach(rvcombinations(n)) do pattern
        Core.eval(mod, liftnoesc(fnm, pattern))
    end
end
# Keyword-arity convenience form: forwards to `lift(f, n; mod)` with `n` defaulting to 3.
lift(f; n=3, mod::Module=@__MODULE__()) = lift(f, n; mod=mod)
## Pre Lifted
## ==========
# `Base` functions to pre-lift, each as a qualified reference expression
# (`Base.:-`, `Base.sum`, ...) suitable for splicing into a method definition.
fnms = [Expr(:., :Base, QuoteNode(Symbol(name))) for name in
        ("-", "+", "*", "/", "^",
         "sin", "cos", "tan",
         "sum",
         "&", "|",
         "sqrt", "abs",
         "getindex",
         "==", ">", ">=", "<=", "<")]
# Base.:^(x1::RandVar, x2::MaybeRV) = ciid(^, x1, x2) # FIXME: Only for 0.7 deprecations
# Base.:^(x1::RandVar, x2::Integer) = ciid(^, x1, x2) # FIXME: Only for 0.7 deprecations
"""
    @lift fnm n

Expand to one lifted method definition of `fnm` per `n`-argument pattern
returned by `rvcombinations(n)` (each pattern has at least one argument typed
`Omega.RandVar`). Each definition is built by `liftesc`, which escapes `fnm`
so it is taken from the macro caller's scope.
"""
macro lift(fnm::Union{Symbol, Expr}, n::Integer)
    # Previously this called `liftmacro`, which is not defined in this file;
    # `liftesc` is the macro-escaped generator with the matching signature.
    defs = map(pattern -> liftesc(fnm, pattern), rvcombinations(n))
    return Expr(:block, defs...)
end
"""
    rvcombinations(n)

Lazily enumerate length-`n` tuples of `Bool` flags, one flag per argument:
`true` marks a random-variable argument and `false` a plain value. The
all-`false` tuple is skipped.
"""
function rvcombinations(n)
    choices = ((true, false) for _ in 1:n)
    alltuples = Iterators.product(choices...)
    return Iterators.filter(any, alltuples)
end
# Largest arity for which every function in `fnms` is pre-lifted at load time.
const MAXN = 4
for fnm in fnms
    for nargs in 1:MAXN
        lift(fnm, nargs)
    end
end
# lift(f::Function) = (args...) -> ciid(f, args...)
# Generated per argument-type signature: if any argument type is a subtype of
# `RandVar`, the call becomes `ciid(f, args...)`; otherwise it is a plain `f(args...)`.
@generated function maybelift(f::Function, args...)
    # Inside the generator `args` holds the argument *types*.
    for argtype in args
        argtype <: RandVar && return :(ciid(f, args...))
    end
    return :(f(args...))
end
"""
    lift(f::Function)

Return an anonymous function that forwards all positional arguments to
`maybelift(f, args...)`.
"""
function lift(f::Function)
    return (args...) -> maybelift(f, args...)
end