
Commit

fix CUDA extension
marius311 committed Mar 30, 2023
1 parent cccbd30 commit 14c6c32
Showing 4 changed files with 51 additions and 29 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -135,4 +135,4 @@ TimerOutputs = "0.5"
Tullio = "0.3"
UnPack = "1"
Zygote = "0.6.21"
julia = "1.7"
julia = "1.9"
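The compat bump is the real content of this file's change: weak-dependency package extensions, the mechanism that lets ext/CMBLensingCUDAExt.jl below load only when CUDA.jl is installed and loaded, were introduced in Julia 1.9, so the old 1.7 lower bound cannot support this layout. A quick way to confirm the extension actually loaded (a sketch; assumes the extension keeps the name CMBLensingCUDAExt used below):

    using CMBLensing, CUDA   # loading CUDA alongside CMBLensing triggers the extension
    ext = Base.get_extension(CMBLensing, :CMBLensingCUDAExt)
    ext === nothing && @warn "CUDA extension did not load; check the [extensions] entry in Project.toml"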
54 changes: 30 additions & 24 deletions ext/CMBLensingCUDAExt.jl
@@ -1,13 +1,15 @@

module CMBLensingCUDAExt

using Adapt
using AbstractFFTs
using CMBLensing
using CUDA
using CUDA.CUSPARSE: CuSparseMatrix, CuSparseMatrixCSR, CuSparseMatrixCOO
using EllipsisNotation
using ForwardDiff
using ForwardDiff: Dual, Partials, value, partials
using LinearAlgebra
using Markdown
using Random
using SparseArrays
@@ -16,31 +18,22 @@ using Zygote
const CuBaseField{B,M,T,A<:CuArray} = BaseField{B,M,T,A}

# printing
typealias(::Type{CuArray{T,N}}) where {T,N} = "CuArray{$T,$N}"
CMBLensing.typealias(::Type{CuArray{T,N}}) where {T,N} = "CuArray{$T,$N}"
Base.print_array(io::IO, X::Diagonal{<:Any,<:CuBaseField}) = Base.print_array(io, cpu(X))


# a function version of @cuda which can be referenced before CUDA is
# loaded as long as it exists by run-time (unlike the macro @cuda which must
# exist at compile-time)
function cuda(f, args...; threads=256)
function CMBLensing.cuda(f, args...; threads=256)
@cuda threads=threads f(args...)
end

CMBLensing.is_gpu_backed(::BaseField{B,M,T,A}) where {B,M,T,A<:CuArray} = true
CMBLensing.gpu(x) = Adapt.adapt_structure(CuArray, x)

# handy conversion functions and macros
@doc doc"""

gpu(x)
Recursively moves x to GPU, but unlike `CUDA.cu`, doesn't also convert
to Float32. Equivalent to `adapt_structure(CuArray, x)`.
"""
gpu(x) = adapt_structure(CuArray, x)


function Cℓ_to_2D(Cℓ, proj::ProjLambert{T,<:CuArray}) where {T}
function CMBLensing.Cℓ_to_2D(Cℓ, proj::ProjLambert{T,<:CuArray}) where {T}
# todo: remove needing to go through cpu here:
gpu(T.(nan2zero.(Cℓ.(cpu(proj.ℓmag)))))
end
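The common thread in this hunk is that every definition now names the module that owns the function (CMBLensing.typealias, CMBLensing.cuda, CMBLensing.gpu, CMBLensing.Cℓ_to_2D): inside a package extension an unqualified `f(x) = ...` creates a brand-new function local to the extension that the parent package never calls, whereas the qualified form adds a method to the parent's existing function (the stubs in src/util.jl further down). A self-contained sketch of the pattern, with hypothetical module and method names:

    module Parent
    function gpu end             # method-less stub, like the ones in src/util.jl
    is_gpu_backed(x) = false     # generic fallback
    end

    module ParentGPUExt
    using ..Parent
    # An unqualified `gpu(x) = ...` here would define ParentGPUExt.gpu, invisible
    # to callers of Parent.gpu; the qualified form extends the parent's function:
    Parent.gpu(x) = map(Float32, x)                   # stand-in for the CuArray conversion
    Parent.is_gpu_backed(::Matrix{Float32}) = true    # stand-in specialization
    end

    Parent.gpu(rand(2, 2))   # dispatches to the method the "extension" added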
@@ -49,17 +42,16 @@ end
### misc
# the generic versions of these trigger scalar indexing of CUDA, so provide
# specialized versions:
pinv(D::Diagonal{T,<:CuBaseField}) where {T} = Diagonal(@. ifelse(isfinite(inv(D.diag)), inv(D.diag), $zero(T)))
inv(D::Diagonal{T,<:CuBaseField}) where {T} = any(Array((D.diag.==0)[:])) ? throw(SingularException(-1)) : Diagonal(inv.(D.diag))
fill!(f::CuBaseField, x) = (fill!(f.arr,x); f)
sum(f::CuBaseField; dims=:) = (dims == :) ? sum_dropdims(f.arr) : (1 in dims) ? error("Sum over invalid dims of CuFlatS0.") : f
LinearAlgebra.pinv(D::Diagonal{T,<:CuBaseField}) where {T} = Diagonal(@. ifelse(isfinite(inv(D.diag)), inv(D.diag), $zero(T)))
LinearAlgebra.inv(D::Diagonal{T,<:CuBaseField}) where {T} = any(Array((D.diag.==0)[:])) ? throw(SingularException(-1)) : Diagonal(inv.(D.diag))
Base.fill!(f::CuBaseField, x) = (fill!(f.arr,x); f)
Base.sum(f::CuBaseField; dims=:) = (dims == :) ? CMBLensing.sum_dropdims(f.arr) : (1 in dims) ? error("Sum over invalid dims of CuFlatS0.") : f
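The comment above this block gives the motivation: the generic LinearAlgebra fallbacks for Diagonal walk the diagonal element by element, which is scalar indexing on a CuArray. A minimal sketch of the failure mode and of the broadcast-based replacement (requires a CUDA-capable GPU):

    using CUDA, LinearAlgebra
    CUDA.allowscalar(false)            # make scalar indexing an error rather than a slow fallback
    d = CUDA.rand(Float32, 4)
    D = Diagonal(d)
    # pinv(D) via the generic fallback indexes D.diag one element at a time and errors here;
    # the broadcast used in the specialized method stays entirely on the GPU:
    Dinv = Diagonal(ifelse.(isfinite.(inv.(d)), inv.(d), zero(Float32)))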

# adapting of SparseMatrixCSC ↔ CuSparseMatrixCSR (otherwise dense arrays created)
adapt_structure(::Type{<:CuArray}, L::SparseMatrixCSC) = CuSparseMatrixCSR(L)
adapt_structure(::Type{<:Array}, L::CuSparseMatrixCSR) = SparseMatrixCSC(L)
adapt_structure(::Type{<:CuArray}, L::CuSparseMatrixCSR) = L
adapt_structure(::Type{<:Array}, L::SparseMatrixCSC) = L

Adapt.adapt_structure(::Type{<:CuArray}, L::SparseMatrixCSC) = CuSparseMatrixCSR(L)
Adapt.adapt_structure(::Type{<:Array}, L::CuSparseMatrixCSR) = SparseMatrixCSC(L)
Adapt.adapt_structure(::Type{<:CuArray}, L::CuSparseMatrixCSR) = L
Adapt.adapt_structure(::Type{<:Array}, L::SparseMatrixCSC) = L
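Without these four methods, adapting a SparseMatrixCSC to the GPU falls through to a generic conversion that produces a dense CuArray (as the comment above notes). A round-trip sketch of the intended behaviour, assuming a GPU and assuming the package's gpu/cpu helpers wrap adapt_structure as shown in this diff:

    using CMBLensing, CUDA, SparseArrays
    using CUDA.CUSPARSE: CuSparseMatrixCSR
    A = sprand(1000, 1000, 0.001)   # SparseMatrixCSC on the CPU
    Agpu = gpu(A)                   # CuSparseMatrixCSR, not a dense CuArray
    Acpu = cpu(Agpu)                # back to SparseMatrixCSC
    Acpu == A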

# some Random API which CUDA doesn't implement yet
Random.randn(rng::CUDA.CURAND.RNG, T::Random.BitFloatType) =
@@ -77,13 +69,13 @@ Random.randn!(rng::MersenneTwister, A::CuArray) =
Garbage collect and reclaim GPU memory (technically should never be
needed to do this by hand, but sometimes helps with GPU OOM errors)
"""
function cuda_gc()
function CMBLensing.cuda_gc()
isdefined(Main,:Out) && empty!(Main.Out)
GC.gc(true)
CUDA.reclaim()
end

unsafe_free!(x::CuArray) = CUDA.unsafe_free!(x)
CMBLensing.unsafe_free!(x::CuArray) = CUDA.unsafe_free!(x)
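A short usage sketch for the two memory helpers (names as defined in this diff; requires a GPU):

    using CMBLensing, CUDA
    x = CUDA.rand(Float32, 4096, 4096)
    # ... work with x ...
    CMBLensing.unsafe_free!(x)   # hand the buffer back to CUDA's allocator right away
    CMBLensing.cuda_gc()         # full GC plus CUDA.reclaim(), e.g. after an OOM error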

@static if CMBLensing.versionof(Zygote)>v"0.6.11"
# https://github.com/JuliaGPU/CUDA.jl/issues/982
@@ -120,4 +112,18 @@ function ForwardDiff.extract_gradient_chunk!(::Type{T}, result::CuArray, dual, i
return result
end

# fix for https://github.com/jonniedie/ComponentArrays.jl/issues/193
function Base.reshape(a::CuArray{T,M}, dims::Tuple{}) where {T,M}
if prod(dims) != length(a)
throw(DimensionMismatch("new dimensions $(dims) must be consistent with array size $(size(a))"))
end

if 0 == M && dims == size(a)
return a
end

CUDA._derived_array(T, 0, a, dims)
end
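This is the GPU twin of the Array method added to src/util.jl further down; both address the linked ComponentArrays issue by giving reshape-to-zero-dimensions a dedicated method. A sketch of what the method produces (requires a GPU; assumes CUDA._derived_array returns a view over the same buffer, as its use here implies):

    x = CUDA.rand(Float32, 1)
    y = reshape(x, ())             # 0-dimensional CuArray over x's memory
    size(y) == ()
    CUDA.@allowscalar y[] == x[1]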


end
2 changes: 1 addition & 1 deletion src/proj_lambert.jl
@@ -134,7 +134,7 @@ promote_metadata_generic(metadata₁::ProjLambert, metadata₂::ProjLambert) =
# broadcast, thus avoiding allocating any temporary arrays.

function preprocess((_,proj)::Tuple{<:BaseFieldStyle,<:ProjLambert{T,V}}, r::Real) where {T,V}
r isa BatchedReal ? adapt(V, reshape(r.vals, 1, 1, 1, :)) : r
r isa BatchedReal ? adapt(basetype(V), reshape(r.vals, 1, 1, 1, :)) : r
end
# need custom adjoint here bc Δ can come back batched from the
# backward pass even though r was not batched on the forward pass
22 changes: 19 additions & 3 deletions src/util.jl
@@ -255,7 +255,7 @@ macro cpu!(vars...)
:(begin; $((:($(esc(var)) = cpu($(esc(var)))) for var in vars)...); nothing; end)
end


# stubs filled in by extension module:
@doc doc"""
gpu(x)
@@ -265,8 +265,8 @@ this does not change the `eltype` of any underlying arrays. See also
[`cpu`](@ref).
"""
function gpu end

function cuda_gc end
function cuda end
is_gpu_backed(x) = false
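A usage sketch of these stubs once CUDA is loaded and the extension fills them in (requires a GPU; assumes gpu and cpu are exported, as their user-facing docstrings suggest):

    using CMBLensing, CUDA       # loading both activates CMBLensingCUDAExt
    x = rand(256, 256)           # Float64 on the CPU
    xg = gpu(x)                  # CuArray{Float64,2}; eltype preserved, unlike CUDA.cu
    cpu(xg) == x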


@@ -499,4 +499,20 @@ end
ensure_dense(vec::AbstractVector) = vec
ensure_dense(vec::SparseVector) = collect(vec)

unsafe_free!(x::AbstractArray) = nothing
unsafe_free!(x::AbstractArray) = nothing


# fix for https://github.com/jonniedie/ComponentArrays.jl/issues/193
function Base.reshape(a::Array{T,M}, dims::Tuple{}) where {T,M}
throw_dmrsa(dims, len) =
throw(DimensionMismatch("new dimensions $(dims) must be consistent with array size $len"))

if prod(dims) != length(a)
throw_dmrsa(dims, length(a))
end
Base.isbitsunion(T) && return ReshapedArray(a, dims, ())
if 0 == M && dims == size(a)
return a
end
ccall(:jl_reshape_array, Array{T,0}, (Any, Any, Any), Array{T,0}, a, dims)
end
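This is the CPU-side counterpart of the CuArray method added in the extension above, likewise tied to the linked ComponentArrays issue. A minimal illustration of the behaviour it covers (plain Julia, no GPU needed):

    x0 = reshape([3.0], ())      # 0-dimensional Array{Float64,0} over the same data
    size(x0) == ()
    x0[] == 3.0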
