Skip to content

Commit

Permalink
switch to PythonCall over PyCall
Browse files Browse the repository at this point in the history
  • Loading branch information
marius311 committed Apr 2, 2023
1 parent dafec3d commit 1334fae
Show file tree
Hide file tree
Showing 11 changed files with 97 additions and 117 deletions.
13 changes: 7 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ PlotUtils = "995b91a9-d308-5afd-9ec6-746e21dbc043"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Expand All @@ -73,7 +72,9 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
MuseInference = "43b88160-90c7-4f71-933b-9d65205cd921"
PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
PythonPlot = "274fc56d-3b97-40fa-a1cd-1b4a50311bf9"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"


# should duplicate [weakdeps]
# see https://pkgdocs.julialang.org/dev/creating-packages/#Using-an-extension-while-supporting-older-Julia-versions
Expand All @@ -82,12 +83,14 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
MuseInference = "43b88160-90c7-4f71-933b-9d65205cd921"
PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
PythonPlot = "274fc56d-3b97-40fa-a1cd-1b4a50311bf9"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"

[extensions]
CMBLensingCUDAExt = "CUDA"
CMBLensingMuseInferenceExt = "MuseInference"
CMBLensingPyPlotExt = "PyPlot"
CMBLensingPythonCallExt = "PythonCall"
CMBLensingPythonPlotExt = "PythonPlot"

[compat]
AbstractFFTs = "0.5, 1"
Expand Down Expand Up @@ -129,8 +132,6 @@ PDMats = "0.11.5"
PlotUtils = "1.3.2"
Preferences = "1.2"
ProgressMeter = "1.2"
PyCall = "1.91.2"
PyPlot = "2"
QuadGK = "2.3.1"
RecipesBase = "1.3.2"
Requires = "0.5, 1"
Expand Down
1 change: 1 addition & 0 deletions ext/CMBLensingCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
module CMBLensingCUDAExt

using CMBLensing

if isdefined(Base, :get_extension)
using CUDA
using CUDA.CUSPARSE: CuSparseMatrix, CuSparseMatrixCSR, CuSparseMatrixCOO
Expand Down
1 change: 1 addition & 0 deletions ext/CMBLensingMuseInferenceExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
module CMBLensingMuseInferenceExt

using CMBLensing

if isdefined(Base, :get_extension)
using MuseInference
using MuseInference: AD, AbstractMuseProblem, MuseResult, Transformedθ, UnTransformedθ
Expand Down
16 changes: 16 additions & 0 deletions ext/CMBLensingPythonCallExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@

module CMBLensingPythonCallExt

using CMBLensing
using CMBLensing: extrapolate_Cℓs

if isdefined(Base, :get_extension)
import PythonCall
else
import ..PythonCall
end

CMBLensing.pyimport(x) = PythonCall.pyimport(x)
CMBLensing.PyArray(x) = PythonCall.PyArray(x)

end
41 changes: 20 additions & 21 deletions ext/CMBLensingPyPlotExt.jl → ext/CMBLensingPythonPlotExt.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
module CMBLensingPyPlotExt
module CMBLensingPythonPlotExt

using CMBLensing

if isdefined(Base, :get_extension)
using PyPlot
using PyPlot.PyCall
using PythonPlot
using PythonPlot.PythonCall
else
using ..PyPlot
using ..PyPlot.PyCall
using ..PythonPlot
using ..PythonPlot.PythonCall
end

using FFTW
Expand All @@ -17,7 +18,7 @@ using StatsBase

### overloaded 1D plotting

for plot in (:(PyPlot.plot), :(PyPlot.loglog), :(PyPlot.semilogx), :(PyPlot.semilogy))
for plot in (:(PythonPlot.plot), :(PythonPlot.loglog), :(PythonPlot.semilogx), :(PythonPlot.semilogy))

# Cℓs
@eval function ($plot)(ic::Cℓs, args...; kwargs...)
Expand Down Expand Up @@ -51,7 +52,7 @@ for plot in (:(PyPlot.plot), :(PyPlot.loglog), :(PyPlot.semilogx), :(PyPlot.semi
end

# 2D KDE
function PyPlot.plot(k::CMBLensing.GetDistKDE{2}, args...; color=nothing, label=nothing, levels=[0.95,0.68], filled=true, kwargs...)
function PythonPlot.plot(k::CMBLensing.GetDistKDE{2}, args...; color=nothing, label=nothing, levels=[0.95,0.68], filled=true, kwargs...)
@unpack colors = pyimport("matplotlib")
args = k.kde.x, k.kde.y, k.kde.P, [k.kde.getContourLevels(levels); Inf]
if color == nothing
Expand All @@ -62,7 +63,7 @@ function PyPlot.plot(k::CMBLensing.GetDistKDE{2}, args...; color=nothing, label=
end

# Cℓ band
function PyPlot.fill_between(ic::Cℓs{<:Measurement}, args...; kwargs...)
function PythonPlot.fill_between(ic::Cℓs{<:Measurement}, args...; kwargs...)
fill_between(
ic.ℓ,
((@. Measurements.value(ic.Cℓ) - x * Measurements.uncertainty(ic.Cℓ)) for x in (-1,1))...,
Expand Down Expand Up @@ -161,11 +162,9 @@ function _plot(f, ax, k, title, vlim, vscale, cmap; cbar=true, units=:deg, tickl
ax.set_title(title, y=1)
if ticklabels
if ismap
@pydef mutable struct MyFmt <: pyimport(:matplotlib).ticker.ScalarFormatter
__call__(self,v,p=nothing) = py"super"(MyFmt,self).__call__(v,p)*Dict(:deg=>"°",:arcmin=>"")[units]
end
ax.xaxis.set_major_formatter(MyFmt())
ax.yaxis.set_major_formatter(MyFmt())
u = Dict(:deg=>"°", :arcmin=>"")[units]
ax.xaxis.set_major_formatter("{x}"*u)
ax.yaxis.set_major_formatter("{x}"*u)
if axeslabels
if f isa LambertField
ax.set_xlabel("x")
Expand Down Expand Up @@ -194,8 +193,8 @@ end
Plotting fields.
"""
PyPlot.plot(f::Field; kwargs...) = plot([f]; kwargs...)
function PyPlot.plot(D::DiagOp; kwargs...)
PythonPlot.plot(f::Field; kwargs...) = plot([f]; kwargs...)
function PythonPlot.plot(D::DiagOp; kwargs...)
props = _sub_components[findfirst(((k,v),)->diag(D) isa @eval($k), _sub_components)][2]
plot(
[diag(D)];
Expand All @@ -204,7 +203,7 @@ function PyPlot.plot(D::DiagOp; kwargs...)
)
end

function PyPlot.plot(
function PythonPlot.plot(
fs :: AbstractVecOrMat{F};
plotsize = 4,
which = default_which(fs),
Expand All @@ -222,8 +221,8 @@ function PyPlot.plot(
aspect = all('x' in string(w) for w in [""] .* string.(which)) ? fs[1].Nx / fs[1].Ny : 1
end
figsize = plotsize .* [1.4 * n * aspect, m]
fig,axs = subplots(m, n; figsize, squeeze=false)
axs = getindex.(Ref(axs), 1:m, (1:n)') # see https://github.com/JuliaPy/PyCall.jl/pull/487#issuecomment-456998345
fig, axs = subplots(m, n; figsize, squeeze=false)
axs = getindex.(Ref(PyArray(axs)), 1:m, (1:n)') # see https://github.com/JuliaPy/PythonCall.jl/pull/487#issuecomment-456998345
_plot.(fs,axs,which,title,vlim,vscale,cmap; kwargs...)

if return_all
Expand Down Expand Up @@ -288,8 +287,8 @@ end

### plotting HealpixFields

function PyPlot.plot(f::HealpixMap; kwargs...)
hp.projview(
function PythonPlot.plot(f::HealpixMap; kwargs...)
pyimport("healpy").projview(
collect(f.arr);
cmap = "RdBu_r",
graticule = true,
Expand All @@ -312,7 +311,7 @@ end
### convenience
# for plotting in environments that only show a plot if its the last thing returned

function PyPlot.figure(plotfn::Function, args...; kwargs...)
function PythonPlot.figure(plotfn::Function, args...; kwargs...)
figure(args...; kwargs...)
plotfn()
gcf()
Expand Down
3 changes: 2 additions & 1 deletion src/CMBLensing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ include("autodiff.jl")
end
# some stubs filled in by extensions
function animate end

function pyimport end
function PyArray end

# misc init
# see https://github.com/timholy/ProgressMeter.jl/issues/71 and links therein
Expand Down
6 changes: 3 additions & 3 deletions src/chains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ bootstrap resampling using the calculated "effective sample size" of the chain.
"""
function mean_std_and_errors(samples; N_bootstrap=10000, N_in_paren=2, tol=50)

Neff = round(Int, length(samples) / @ondemand(PyCall.pyimport)(:emcee).autocorr.integrated_time(samples; tol)[1])
Neff = round(Int, length(samples) / PyArray(pyimport("emcee").autocorr.integrated_time(samples; tol))[1])

μ = mean(samples)
σ = std(samples)
Expand Down Expand Up @@ -234,7 +234,7 @@ Based on Python [GetDist](https://getdist.readthedocs.io/en/latest/intro.html),
which must be installed.
"""
function kde(samples::AbstractVector; boundary=(nothing,nothing), normalize="integral", smooth_scale_1D=nothing)
getdist = @ondemand(PyCall.pyimport)("getdist")
getdist = pyimport("getdist")
getdist.chains.print_load_details = false
kde = getdist.MCSamples(;
samples, weights=nothing, names=["x"], ranges=Dict("x"=>boundary)
Expand All @@ -249,7 +249,7 @@ function kde(samples::AbstractMatrix; boundary=((nothing,nothing),(nothing,nothi
elseif size(samples,2) != 2
error("KDE only supports 1 or 2 dimensional samples.")
end
getdist = @ondemand(PyCall.pyimport)("getdist")
getdist = pyimport("getdist")
getdist.chains.print_load_details = false
kde = getdist.MCSamples(;
samples, weights=nothing, names=["x","y"], ranges=Dict("x"=>boundary[1], "y"=>boundary[2])
Expand Down
96 changes: 46 additions & 50 deletions src/cls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ end

### camb interface (via Python/pycamb)

_default_Cℓs_path = joinpath(@__DIR__, "../dat/default_camb_Cls.jld2")
_default_Cℓs_params = isfile(_default_Cℓs_path) ? Dict(pairs(load(_default_Cℓs_path, "params"))) : Dict()
const _default_Cℓs_path = joinpath(@__DIR__, "../dat/default_camb_Cls.jld2")
const _default_Cℓs_params = isfile(_default_Cℓs_path) ? Dict(pairs(load(_default_Cℓs_path, "params"))) : Dict()
@memoize _default_Cℓs() = load(_default_Cℓs_path, "Cℓ")

@memoize function camb(;
Expand All @@ -150,60 +150,56 @@ _default_Cℓs_params = isfile(_default_Cℓs_path) ? Dict(pairs(load(_default_C
return _default_Cℓs()
end

camb = @ondemand(PyCall.pyimport)(:camb)
camb = pyimport("camb")

Base.invokelatest() do

ℓmax′ = min(5000,ℓmax)
cp = camb.set_params(
ombh2 = ωb,
omch2 = ωc,
tau = τ,
mnu = Σmν,
cosmomc_theta = θs,
H0 = H0,
ns = nₛ,
nt = nₜ,
As = exp(logA)*1e-10,
pivot_scalar = k_pivot,
pivot_tensor = k_pivot,
lmax = ℓmax′,
r = r,
Alens = AL,
)
cp.max_l_tensor = ℓmax′
cp.max_eta_k_tensor = 2ℓmax′
cp.WantScalars = true
cp.WantTensors = true
cp.DoLensing = true
cp.set_nonlinear_lensing(true)

res = camb.get_results(cp)


= collect(2:ℓmax -1)
ℓ′ = collect(2:ℓmax′-1)
α = (10^6*cp.TCMB)^2
toCℓ′ = @. 1/(ℓ′*(ℓ′+1)/(2π))

Cℓϕϕ = extrapolate_Cℓs(ℓ,ℓ′,2π*res.get_lens_potential_cls(ℓmax′)[3:ℓmax′,1]./ℓ′.^4)
ℓmax′ = min(5000, ℓmax)
cp = camb.set_params(
ombh2 = ωb,
omch2 = ωc,
tau = τ,
mnu = Σmν,
cosmomc_theta = θs,
H0 = H0,
ns = nₛ,
nt = nₜ,
As = exp(logA) * 1e-10,
pivot_scalar = k_pivot,
pivot_tensor = k_pivot,
lmax = ℓmax′,
r = r,
Alens = AL,
)
cp.max_l_tensor = ℓmax′
cp.max_eta_k_tensor = 2ℓmax′
cp.WantScalars = true
cp.WantTensors = true
cp.DoLensing = true
cp.set_nonlinear_lensing(true)

return (;
map(["unlensed_scalar","lensed_scalar","tensor","unlensed_total","total"]) do k
Symbol(k) => (;
map(enumerate([:TT,:EE,:BB,:TE])) do (i,x)
Symbol(x) => extrapolate_Cℓs(ℓ,ℓ′,res.get_cmb_power_spectra()[k][3:ℓmax′,i].*toCℓ′.*α)
end...,
ϕϕ = Cℓϕϕ
)
end...,
params = (;params...)
)
res = camb.get_results(cp)

end
= collect(2:ℓmax -1)
ℓ′ = collect(2:ℓmax′-1)
α = (10^6 * cp.TCMB)^2
toCℓ′ = @. 1/(ℓ′*(ℓ′+1)/(2π))

Cℓϕϕ = extrapolate_Cℓs(ℓ, ℓ′, 2π * PyArray(res.get_lens_potential_cls(ℓmax′))[3:ℓmax′,1] ./ ℓ′.^4)

return (;
map(["unlensed_scalar","lensed_scalar","tensor","unlensed_total","total"]) do k
Symbol(k) => (;
map(enumerate([:TT,:EE,:BB,:TE])) do (i,x)
Symbol(x) => extrapolate_Cℓs(ℓ, ℓ′, PyArray* res.get_cmb_power_spectra()[k])[3:ℓmax′,i] .* toCℓ′)
end...,
ϕϕ = Cℓϕϕ
)
end...,
params = (;params...)
)

end


@doc """
load_camb_Cℓs(;path_prefix, custom_tensor_params=nothing,
unlensed_scalar_postfix, unlensed_tensor_postfix, lensed_scalar_postfix, lenspotential_postfix)
Expand Down
5 changes: 0 additions & 5 deletions src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,3 @@ getindex(x::Union{Real,Field,FieldOp}, ::typeof(!), I) = batch_index(x, I)
one(f::Field) = fill!(similar(f), one(eltype(f)))
norm(f::Field) = sqrt(dot(f,f))
# sum_kbn(f::Field) = sum_kbn(f[:])

@init @require PyCall="438e738f-606a-5dbb-bf0a-cddfbfd45ab0" begin
# never try to auto-convert Fields or FieldOps to Python arrays
PyCall.PyObject(x::Union{FieldOp,Field}) = PyCall.pyjlwrap_new(x)
end
4 changes: 1 addition & 3 deletions src/proj_healpix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
# The main functionality of broadcasting, indexing, and projection for
# a few field types is implemented, but not much beyond that.

@init global hp = lazy_pyimport("healpy")

struct ProjHealpix <: Proj
Nside :: Int
end
Expand Down Expand Up @@ -224,7 +222,7 @@ function project(projector::Projector{:bilinear}, (hpx_map, cart_proj)::Pair{<:H
@assert projector.hpx_proj == hpx_map.proj && projector.cart_proj == cart_proj
@unpack (Ny, Nx, T) = cart_proj
@unpack (θs, ϕs) = projector
BaseMap(T.(reshape(hp.get_interp_val(collect(hpx_map), θs, ϕs), Ny, Nx)), cart_proj)
BaseMap(T.(reshape(PyArray(pyimport("healpy").get_interp_val(collect(hpx_map), θs, ϕs)), Ny, Nx)), cart_proj)
end

function project(projector::Projector{:fft}, (hpx_map, cart_proj)::Pair{<:HealpixMap,<:CartesianProj})
Expand Down

0 comments on commit 1334fae

Please sign in to comment.