Any plan for including multidimensional linear operators? #276

Open
ehgus opened this issue Jun 4, 2023 · 2 comments

Comments

@ehgus

ehgus commented Jun 4, 2023

Hello, do you have any plans to add array-to-array support to LinearOperator?
The current LinearOperator supports only vector-to-vector maps.
This is limiting because I have to do redundant reshaping before applying multidimensional linear functions such as `fft`.

I hope it can explicitly support such linear functions.

@dpo
Member

dpo commented Jun 4, 2023

Do you mean something like this?

```julia
op = LinearOperator(...)  # any linear operator with m columns
A = rand(m, n)
B = op * A
```

If so, you should already be able to do it, but B is another linear operator. You can materialize it with Matrix(B).
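For reference, here is a minimal runnable sketch of that composition. It assumes the `op * A` and `Matrix(op)` methods of LinearOperators.jl described in the reply. The sizes `m`, `n` and the matrix-backed `op` are placeholders chosen only for illustration.

```julia
using LinearOperators

m, n = 4, 3
Mop = rand(m, m)            # placeholder matrix backing the operator
op = LinearOperator(Mop)    # an m×m linear operator
A = rand(m, n)

B = op * A                  # per the reply above, B is itself a linear operator (m×n)
M = Matrix(B)               # materialize B as a dense m×n matrix

println(size(B))            # dimensions of the composed operator
println(M ≈ Mop * A)        # the materialized result matches the plain matrix product
```

Materializing is only practical for small operators, because `Matrix(B)` builds a dense m×n array.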

@ehgus
Author

ehgus commented Jun 5, 2023

What I mean is something like this:

```julia
sz = (6, 6, 6)
op = LinearOperator(Float64, sz, sz, ...)  # proposed: input array of size (6,6,6) -> output array of size (6,6,6)
A = rand(sz...)                            # random 6×6×6 array
B = op * A                                 # would return an array of size (6,6,6)
```
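As far as this thread shows, LinearOperators.jl has no constructor that takes array shapes like the one proposed. The requested array-in, array-out behaviour can be sketched on top of an ordinary vector-to-vector operator. `ArrayOperator` is a hypothetical name and is not part of the library. The diagonal scaling built with `opDiagonal` only stands in for a real multidimensional map such as an FFT.

```julia
using LinearOperators

# Hypothetical wrapper (not part of LinearOperators.jl): it stores the input and
# output array shapes and reshapes around an ordinary vector-to-vector operator.
struct ArrayOperator{O<:AbstractLinearOperator, N, M}
    op::O
    insize::NTuple{N, Int}
    outsize::NTuple{M, Int}
end

# Apply to an array of shape `insize`, returning an array of shape `outsize`.
function Base.:*(A::ArrayOperator, X::AbstractArray)
    size(X) == A.insize ||
        throw(DimensionMismatch("expected input of size $(A.insize), got $(size(X))"))
    return reshape(A.op * vec(X), A.outsize)
end

# The adjoint maps arrays of shape `outsize` back to arrays of shape `insize`.
Base.adjoint(A::ArrayOperator) = ArrayOperator(A.op', A.outsize, A.insize)

# Usage: element-wise scaling of a 6×6×6 array, standing in for a real
# multidimensional map.
sz = (6, 6, 6)
W = rand(sz...)
op = ArrayOperator(opDiagonal(vec(W)), sz, sz)

A = rand(sz...)
B = op * A          # 6×6×6 array, equal to W .* A
C = op' * B         # adjoint applied to an array, also 6×6×6
```

All numerical work stays in the existing operator. The wrapper only moves the `vec`/`reshape` bookkeeping out of user code. Solvers that expect vectors would still need the inner `op.op`.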

Many linear functions on multidimensional arrays (tensors) are already available in Julia.

As a workaround, I wrap such functions myself. For example, for a 3-D FFT (input: 3-D array, output: 3-D array), I pass a vectorized array to the LinearOperator and reshape it back to the original shape inside the operator before calling the FFT.

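```julia
using LinearOperators
using FFTW

const sz = (6, 6, 6)
X = randn(ComplexF64, sz) |> vec   # LinearOperator acts on vectors, so flatten the input

# 5-argument in-place product: res = α * fft(v) + β * res
function reshape_fft!(res, v, α, β)
    v = reshape(v, sz)   # restore the 3-D shape before transforming
    if β == 0
        res .= α .* vec(fft(v))
    else
        res .= α .* vec(fft(v)) .+ β .* res
    end
end

# Adjoint product. fft is unnormalized while ifft divides by prod(sz),
# so the exact adjoint of fft is the unnormalized backward transform bfft.
function reshape_bfft!(res, v, α, β)
    v = reshape(v, sz)
    if β == 0
        res .= α .* vec(bfft(v))
    else
        res .= α .* vec(bfft(v)) .+ β .* res
    end
end

# Arguments: eltype, nrows, ncols, symmetric, hermitian, prod!, tprod!, ctprod!
dft = LinearOperator(ComplexF64, prod(sz), prod(sz), false, false, reshape_fft!, nothing, reshape_bfft!)

rst = reshape(dft * X, sz)   # reshape the result back to 3-D
```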
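One way to sanity-check a wrapped operator like `dft` is the dot-product (adjoint) test ⟨F x, y⟩ = ⟨x, Fᴴ y⟩. FFTW's `ifft` includes a 1/N normalization. As a result, the exact adjoint of the unnormalized `fft` is the unnormalized backward transform `bfft`, and with `ifft` in the `ctprod!` slot the test would be off by a factor of N. The self-contained sketch below therefore uses `bfft`. The names `dims3`, `apply_reshaped!`, `fft3!` and `bfft3!` are illustrative only.

```julia
using LinearOperators, FFTW, LinearAlgebra

const dims3 = (6, 6, 6)   # illustrative array shape
const N = prod(dims3)

# Shared body of a 5-argument in-place product res = α * f(v) + β * res,
# where f acts on arrays of shape dims3.
function apply_reshaped!(f, res, v, α, β)
    y = vec(f(reshape(v, dims3)))
    if iszero(β)
        res .= α .* y
    else
        res .= α .* y .+ β .* res
    end
    return res
end

fft3!(res, v, α, β) = apply_reshaped!(fft, res, v, α, β)     # forward, unnormalized
bfft3!(res, v, α, β) = apply_reshaped!(bfft, res, v, α, β)   # exact adjoint of fft3!

F = LinearOperator(ComplexF64, N, N, false, false, fft3!, nothing, bfft3!)

x = randn(ComplexF64, N)
y = randn(ComplexF64, N)

lhs = dot(F * x, y)    # ⟨F x, y⟩ (dot conjugates its first argument)
rhs = dot(x, F' * y)   # ⟨x, Fᴴ y⟩
println(lhs ≈ rhs)     # adjoint test
```

Applying `F'` after `F` returns `N .* x` rather than `x`. If an exact inverse is needed as well, it can be written separately with `ifft`.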
