Skip to content

Commit

Permalink
Revert "memoize buffer needed for ldiv! plans"
Browse files Browse the repository at this point in the history
possibly causes some not-yet-understood errors for MUSE implicit-diff

This reverts commit 7a6a1b2.
  • Loading branch information
marius311 committed Mar 20, 2023
1 parent e31e2ff commit 92c731a
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 33 deletions.
14 changes: 0 additions & 14 deletions src/gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,20 +86,6 @@ end
Base.view(arr::CuArray{T,2}, I, J, K, ::typeof(..)) where {T} = view(arr, I, J, K)
Base.view(arr::CuArray{T,3}, I, J, K, ::typeof(..)) where {T} = view(arr, I, J, K)

# CUFFT destroys the input array for irfft so a copy is needed but
# CUDA.jl's allocates a new array for this. here do it instead into
# some memoized memory and avoid allocations
function ldiv_safe!(dst, plan::CUDA.CUFFT.rCuFFTPlan{CUDA.CUFFT.cufftReal}, src)
inv_plan = inv(plan)
CUDA.CUFFT.cufftExecC2R(inv_plan.p, copy_into_irfft_cache(src), dst)
LinearAlgebra.lmul!(inv_plan.scale, dst)
end
function ldiv_safe!(dst, plan::CUDA.CUFFT.rCuFFTPlan{CUDA.CUFFT.cufftDoubleReal}, src)
inv_plan = inv(plan)
CUDA.CUFFT.cufftExecZ2D(inv_plan.p, copy_into_irfft_cache(src), dst)
LinearAlgebra.lmul!(inv_plan.scale, dst)
end


## ForwardDiff through FFTs
# these definitions needed bc the CUDA.jl definitions supersede the
Expand Down
23 changes: 4 additions & 19 deletions src/util_fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ function m_irfft(arr::AbstractArray{T,N}, d, dims) where {T,N}
m_plan_rfft(_fft_arr_type(arr){real(T),N}, dims, output_size...) \ arr
end
m_rfft!(dst, arr::AbstractArray{T,N}, dims) where {T<:Real,N} = mul!(dst, m_plan_rfft(_fft_arr_type(arr){T,N}, dims, size(arr)...), arr)
m_irfft!(dst, arr::AbstractArray{T,N}, dims) where {T,N} = ldiv_safe!(dst, m_plan_rfft(_fft_arr_type(arr){real(T),N}, dims, size(dst)...), arr)
m_irfft!(dst, arr::AbstractArray{T,N}, dims) where {T,N} = ldiv!(dst, m_plan_rfft(_fft_arr_type(arr){real(T),N}, dims, size(dst)...), copy_if_fftw(arr))
m_fft(arr::AbstractArray{T,N}, dims) where {T,N} = m_plan_fft(_fft_arr_type(arr){complex(T),N}, dims, size(arr)...) * arr
m_ifft(arr::AbstractArray{T,N}, dims) where {T,N} = m_plan_fft(_fft_arr_type(arr){complex(T),N}, dims, size(arr)...) \ arr
m_fft!(dst, arr::AbstractArray{T,N}, dims) where {T,N} = mul!(dst, m_plan_fft(_fft_arr_type(arr){complex(T),N}, dims, size(arr)...), complex(arr))
Expand All @@ -38,24 +38,9 @@ end
plan_fft(A(undef, sz...), dims; (A <: Array ? (timelimit=FFTW_TIMELIMIT,) : ())...)
end
Zygote.@nograd m_plan_fft, m_plan_rfft


# FFTW and CUFFT (but not MKL) destroy the input array for inverse
# real FFTs, so we need a copy. see
# https://github.com/JuliaMath/FFTW.jl/issues/158. do the copy into
# some memoized memory to avoid allocation.
copy_into_irfft_cache(arr) = copy!(irfft_cache(typeof(arr), size(arr)), arr)
@memoize irfft_cache(Arr, sz) = Arr(undef, sz...)
function ldiv_safe!(dst, plan::FFTW.rFFTWPlan, src)
if FFTW.fftw_provider == "fftw"
ldiv!(dst, plan, copy_into_irfft_cache(src))
else
ldiv!(dst, plan, src)
end
end



# FFTW (but not MKL) destroys the input array for inplace inverse real
# FFTs, so we need a copy. see https://github.com/JuliaMath/FFTW.jl/issues/158
copy_if_fftw(x) = (x isa Array && FFTW.fftw_provider == "fftw") ? copy(x) : x


"""
Expand Down

0 comments on commit 92c731a

Please sign in to comment.