-
Notifications
You must be signed in to change notification settings - Fork 17
/
Prim.jl
109 lines (93 loc) · 2.34 KB
/
Prim.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
module Prim
using ..Omega
using ..Omega: Ω, RandVar, URandVar, MaybeRV, ID, lift, uid, elemtype, isconstant
import ..Omega: params, name, ppapl, apl, reify
import Statistics: mean, var, quantile
import ..Causal: ReplaceRandVar
using ..Util
using Spec
import Distributions
const Djl = Distributions
import Base: minimum, maximum
# Public constructors for primitive random variables. These are presumably defined
# in the included univariate.jl / multivariate.jl; `β` and `Γ` appear to be Unicode
# aliases for `betarv` and `gammarv` (confirm there). `dirichlet` and `mvnormal`
# are currently commented out of the list and so are not exported.
# `mean` is imported from Statistics at the top of the module and re-exported here.
export bernoulli,
betarv,
β,
categorical,
# dirichlet,
exponential,
gammarv,
Γ,
invgamma,
kumaraswamy,
logistic,
# mvnormal,
normal,
poisson,
rademacher,
uniform,
mean
"""
Primitive random variable of known distribution.

Abstract supertype for this module's primitive random variables. `params` treats
every field of a concrete subtype except `id` as a distribution parameter, and
`ppapl` applies a subtype through `rvtransform`.
"""
abstract type PrimRandVar <: RandVar end
"Name of a distribution"
function name end

"""
    name(rv::PrimRandVar)

Return the name of the concrete type of `rv` as a `Symbol`, without its type
parameters (a value of type `Foo{Float64}` yields `:Foo`).
"""
# `nameof` is the public accessor for a type's name; `T.name.name` reaches into
# the internal `DataType`/`TypeName` layout and yields the same Symbol.
name(::T) where {T <: PrimRandVar} = nameof(T)
"""
    params(rv::PrimRandVar)

Return a tuple of `rv.f` for every field `f` of `rv`'s concrete type, in
declaration order, skipping the `id` field.
"""
@generated function params(rv::PrimRandVar)
    # In the generator `rv` is bound to the argument's type, so the field list is
    # known statically and the returned expression is a plain tuple literal.
    paramfields = filter(f -> f !== :id, collect(fieldnames(rv)))
    accessors = map(f -> Expr(:., :rv, QuoteNode(f)), paramfields)
    return Expr(:tuple, accessors...)
end
"""
    ppapl(rv::PrimRandVar, ωπ)

Apply primitive random variable `rv` at `ωπ`. Reify `rv`'s parameters (see
`params`) against `ωπ`, then call the function returned by `rvtransform(rv)`
with `ωπ` followed by the reified parameters.
"""
function ppapl(rv::PrimRandVar, ωπ)
    # Same evaluation order as a direct call: transform first, then its arguments.
    transform = rvtransform(rv)
    reified = reify(ωπ, params(rv))
    return transform(ωπ, reified...)
end
"""
    anysize(args::Union{<:AbstractArray, Real}...)

Return the `size` of the first `AbstractArray` among `args`, or `()` when no
argument is an array (including when `args` is empty).
"""
@generated function anysize(args::Union{<:AbstractArray, Real}...)
    # In the generator `args` is the tuple of argument *types*.
    firstarr = findfirst(T -> T <: AbstractArray, args)
    # `findfirst` returns `nothing` when no argument is an array. The previous
    # `isempty(...)` guard only caught the zero-argument case, so all-scalar calls
    # generated `size(args[nothing])` and threw instead of returning `()`.
    if firstarr === nothing
        return :(())
    else
        return :(size(args[$firstarr]))
    end
end
# NOTE(review): this spec runs at module scope, after `anysize` has already
# closed, yet it refers to `args` (the name of `anysize`'s vararg parameter).
# It reads as an intended `anysize` precondition (all array arguments share one
# size). Confirm the placement and when Spec.jl evaluates it.
@spec same(size.(filter(a -> a isa AbstractArray, args)))
include("univariate.jl") # Univariate Distributions
include("multivariate.jl") # Multivariate Distributions
include("statistics.jl") # Distributional properties: mean, variance, etc
# Distributional-property functions (success probability, bounds, moments,
# entropy, prob/lprob, ...), presumably defined or extended by the files
# included above (statistics.jl) — confirm there.
# NOTE(review): `mean` is also exported in the constructor export list; the repeat
# is harmless. `var` and `quantile` are imported from Statistics at the top of the
# module but are not exported here.
export succprob,
failprob,
maximum,
minimum,
islowerbounded,
isupperbounded,
isbounded,
std,
median,
mode,
modes,
skewness,
kurtosis,
isplatykurtic,
ismesokurtic,
isleptokurtic,
entropy,
mean,
prob,
lprob
# Lifted distributional functions: `l`-prefixed counterparts of the property
# functions exported above (presumably built with `lift` so they apply to random
# variables — verify in the included files).
export lsuccprob,
lfailprob,
lmaximum,
lminimum,
lislowerbounded,
lisupperbounded,
lisbounded,
lstd,
lmedian,
lmode,
lmodes,
lskewness,
lkurtosis,
lisplatykurtic,
lismesokurtic,
lisleptokurtic,
lentropy,
lmean
include("djl.jl") # Distributions.jl interop
end