From 912c6ebcacfeba43f4f6b401c0285cf9f38c029c Mon Sep 17 00:00:00 2001 From: Kibwe Tavares Date: Fri, 5 Oct 2018 11:38:06 -0400 Subject: [PATCH 1/3] add lprob --- src/Omega.jl | 12 +++++++----- src/higher/Higher.jl | 2 +- src/higher/rcd.jl | 3 ++- src/lift/lift.jl | 4 ++-- src/primitive/Prim.jl | 6 ++++-- src/primitive/statistics.jl | 18 +++++++----------- 6 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/Omega.jl b/src/Omega.jl index e4234435..487f5c91 100644 --- a/src/Omega.jl +++ b/src/Omega.jl @@ -32,15 +32,15 @@ export RandVar, MaybeRV, ciid, isconstant, elemtype, params include("cond.jl") # Conditioning export cond +# Lifted random variable operatiosn +include("lift/containers.jl") # Array/Tuple primitives +export randarray, randtuple + # Higher-Order Inference include("higher/Higher.jl") using .Higher export rcd, rid, ∥ -# Lifted random variable operatiosn -include("lift/containers.jl") # Array/Tuple primitives -export randarray, randtuple - # Lifting functions to RandVar domain include("lift/lift.jl") export @lift, lift @@ -157,7 +157,9 @@ export succprob, isleptokurtic, entropy, - mean + mean, + prob, + lprob # Lifted distributional functions export lsuccprob, diff --git a/src/higher/Higher.jl b/src/higher/Higher.jl index 97471177..6c34e251 100644 --- a/src/higher/Higher.jl +++ b/src/higher/Higher.jl @@ -1,6 +1,6 @@ module Higher -using ..Omega: RandVar, ciid +using ..Omega: RandVar, ciid, randtuple, cond include("rcd.jl") # Random Conditional Distribution include("rid.jl") # Random Interventional Distribution diff --git a/src/higher/rcd.jl b/src/higher/rcd.jl index 380ddc26..565b352d 100644 --- a/src/higher/rcd.jl +++ b/src/higher/rcd.jl @@ -1,5 +1,6 @@ "Random Conditional Distribution" -rcd(x::RandVar, θ::RandVar, eq = ==ₛ) = ciid(ω -> cond(x, eq(θ, θ(ω)))) +rcd(x::RandVar, θ::RandVar, eq = ==) = ciid(ω -> cond(x, eq(θ, θ(ω)))) +rcd(x::RandVar, θs::Tuple, eq = ==) = rcd(x, randtuple(θs), eq) "`rcd`, x ∥ y" x ∥ y = rcd(x, y) \ No newline at end of file diff --git a/src/lift/lift.jl b/src/lift/lift.jl index 96a88074..90449072 100644 --- a/src/lift/lift.jl +++ b/src/lift/lift.jl @@ -54,8 +54,8 @@ fnms = [:(Base.:-), :(Base.:<), ] -Base.:^(x1::RandVar, x2::MaybeRV) = ciid(^, x1, x2) # FIXME: Only for 0.7 deprecations -Base.:^(x1::RandVar, x2::Integer) = ciid(^, x1, x2) # FIXME: Only for 0.7 deprecations +# Base.:^(x1::RandVar, x2::MaybeRV) = ciid(^, x1, x2) # FIXME: Only for 0.7 deprecations +# Base.:^(x1::RandVar, x2::Integer) = ciid(^, x1, x2) # FIXME: Only for 0.7 deprecations macro lift(fnm::Union{Symbol, Expr}, n::Integer) combinations = Iterators.product(((true,false) for i = 1:n)...) combinations = Iterators.filter(any, combinations) diff --git a/src/primitive/Prim.jl b/src/primitive/Prim.jl index 8bb94ca0..eee111fa 100644 --- a/src/primitive/Prim.jl +++ b/src/primitive/Prim.jl @@ -9,7 +9,7 @@ using ..Util using Spec import Distributions const Djl = Distributions -import Base: minimum +import Base: minimum, maximum export bernoulli, betarv, @@ -79,7 +79,9 @@ export succprob, isleptokurtic, entropy, - mean + mean, + prob, + lprob # Lifted distributional functions export lsuccprob, diff --git a/src/primitive/statistics.jl b/src/primitive/statistics.jl index 7571c09f..7fcf4eec 100644 --- a/src/primitive/statistics.jl +++ b/src/primitive/statistics.jl @@ -16,17 +16,14 @@ # end # mean(xs::RandVar{<:Array}) = RandVar{Float64, false}(mean, (xs,)) +"Sample Mean" +mean(x::RandVar, n) = sum((rand(x, alg = RejectionSample) for i = 1:n)) / n -# "Probability that `x` is `true`" -# prob(x::RandVar{T}, n) where {T <: Bool} = mean(x, n) -# prob(x::RandVar{T}, n = 10000) where { T<: RandVar{Bool}} = RandVar{Float64}(prob, (x, n)) -# lift(:prob, 1) - - -# Issues. -# Must expect that type inference may fail, and allow providing of type -# Also allow separate functions -# +"Probability x is true" +prob(x::RandVar, n, israndvar::Type{Val{false}}) = + sum((rand(x, alg = RejectionSample) for i = 1:n)) / n +lprob(x::RandVar, n) = ciid(prob, x, n, Val{false}) +prob(x::RandVar, n) = prob(x, n, Val{elemtype(x) <: RandVar}) # Specializations const unidistattrs = [:succprob, :failprob, :maximum, :minimum, :islowerbounded, @@ -42,7 +39,6 @@ for func in unidistattrs $(:l *ₛ func)(x::RandVar) = ciid($func, x, Val{false}) $func(x::RandVar) = $func(x, Val{elemtype(x) <: RandVar}) end - @show expr eval(expr) end From 4d0f3852494a6fee0513f18a9e4e38f2e5886e23 Mon Sep 17 00:00:00 2001 From: Zenna Tavares Date: Mon, 8 Oct 2018 01:20:46 -0400 Subject: [PATCH 2/3] dependencies and cleanups statistics --- Manifest.toml | 14 ++++++-------- Project.toml | 3 +-- src/primitive/statistics.jl | 4 +++- src/soft/soft.jl | 1 + 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index b0514111..c279cc03 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -206,10 +206,10 @@ deps = ["Printf"] uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" [[ProgressMeter]] -deps = ["Printf", "Random", "Test"] -git-tree-sha1 = "09e653f4b0a3c44628f0bdd0a0e58bc92e0264ef" +deps = ["Distributed", "Printf", "Random", "Test"] +git-tree-sha1 = "5b55c2c974084eab2689ec0d1d5245561b22aeb0" uuid = "92933f4c-e287-5a05-a399-4b506db050ca" -version = "0.6.0" +version = "0.6.1" [[QuadGK]] deps = ["DataStructures", "LinearAlgebra", "Test"] @@ -267,11 +267,9 @@ deps = ["LinearAlgebra", "Random"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[Spec]] -deps = ["Cassette"] -git-tree-sha1 = "275d533031b8c97b6ca38b83395a435524a3dc46" -repo-rev = "master" -repo-url = "https://github.com/zenna/Spec.jl" -uuid = "526a04b8-654b-11e8-1588-db2a414f95b5" +deps = ["Cassette", "Test"] +git-tree-sha1 = "f1839be7daf5bf850406255be00490273eb380ee" +uuid = "b8ccf107-3a88-5e0f-823b-b838c6a0f327" version = "0.1.3" [[SpecialFunctions]] diff --git a/Project.toml b/Project.toml index 257250db..d16b7474 100644 --- a/Project.toml +++ b/Project.toml @@ -7,10 +7,9 @@ version = "0.1.0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Spec = "526a04b8-654b-11e8-1588-db2a414f95b5" +Spec = "b8ccf107-3a88-5e0f-823b-b838c6a0f327" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" diff --git a/src/primitive/statistics.jl b/src/primitive/statistics.jl index 7fcf4eec..df67127a 100644 --- a/src/primitive/statistics.jl +++ b/src/primitive/statistics.jl @@ -22,9 +22,11 @@ mean(x::RandVar, n) = sum((rand(x, alg = RejectionSample) for i = 1:n)) / n "Probability x is true" prob(x::RandVar, n, israndvar::Type{Val{false}}) = sum((rand(x, alg = RejectionSample) for i = 1:n)) / n -lprob(x::RandVar, n) = ciid(prob, x, n, Val{false}) +lprob(x::RandVar, n = 1000) = ciid(prob, x, n, Val{false}) prob(x::RandVar, n) = prob(x, n, Val{elemtype(x) <: RandVar}) +# const lmean = lprob + # Specializations const unidistattrs = [:succprob, :failprob, :maximum, :minimum, :islowerbounded, :isupperbounded, :isbounded, :std, :median, :mode, :modes, diff --git a/src/soft/soft.jl b/src/soft/soft.jl index 418cd311..5e739f3e 100644 --- a/src/soft/soft.jl +++ b/src/soft/soft.jl @@ -26,6 +26,7 @@ function d end @inline d(x::Real, y::Real) = (xy = (x - y); xy * xy) # @inline d(x::Vector{<:Real}, y::Vector{<:Real}) = norm(x - y) @inline d(x::Vector{<:Real}, y::Vector{<:Real}) = sum(d.(x,y)) +@inline d(x::NTuple{N, <: Real}, y::NTuple{N, <:Real}) where N = sum(d.(x,y)) @inline d(x::Array{<:Real}, y::Array{<:Real}) = norm(x[:] - y[:]) "Soft Equality" From 3fdbcb857c50cfcb7ac2ce2f663ca99704c764c9 Mon Sep 17 00:00:00 2001 From: Zenna Tavares Date: Mon, 8 Oct 2018 01:37:21 -0400 Subject: [PATCH 3/3] rm 0.7 support --- .travis.yml | 1 - REQUIRE | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index d7e263c5..9c94d0d8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,7 +6,6 @@ os: - osx julia: - 1.0 - - 0.7 - nightly notifications: email: false diff --git a/REQUIRE b/REQUIRE index 3f679d14..224d501c 100644 --- a/REQUIRE +++ b/REQUIRE @@ -1,4 +1,4 @@ -julia 0.7 +julia 1.0 Flux Distributions PDMats