Skip to content

Commit

Permalink
Improve array indexing (#142)
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella committed May 23, 2022
1 parent ab0d6fa commit 5b2845a
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 38 deletions.
3 changes: 1 addition & 2 deletions src/functions/indAffineIterative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,10 @@ end
function prox!(y, f::IndAffineIterative{M, V}, x, gamma) where {M, V}
# Von Neumann's alternating projections
R = real(eltype(x))
m = size(f.A, 1)
y .= x
for k = 1:1000
maxres = R(0)
for i = 1:m
for i in eachindex(f.b)
resi = (f.b[i] - dot(f.A[i,:], y))
y .= y + resi*f.A[i,:] # no need to divide: rows of A are normalized
absresi = resi > 0 ? resi : -resi
Expand Down
26 changes: 12 additions & 14 deletions src/functions/indBallL0.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,30 +34,28 @@ function (f::IndBallL0)(x)
return R(0)
end

function _get_top_k_abs_indices(x::AbstractVector, k)
range = firstindex(x):(firstindex(x) + k - 1)
return partialsortperm(x, range, by=abs, rev=true)
end

_get_top_k_abs_indices(x, k) = _get_top_k_abs_indices(x[:], k)

function prox!(y, f::IndBallL0, x, gamma)
T = eltype(x)
p = []
if ndims(x) == 1
p = partialsortperm(x, 1:f.r, by=abs, rev=true)
else
p = partialsortperm(x[:], 1:f.r, by=abs, rev=true)
end
sort!(p)
idx = 1
for i = 1:length(p)
y[idx:p[i]-1] .= T(0)
p = _get_top_k_abs_indices(x, f.r)
y .= T(0)
for i in eachindex(p)
y[p[i]] = x[p[i]]
idx = p[i]+1
end
y[idx:end] .= T(0)
return real(T)(0)
end

function prox_naive(f::IndBallL0, x, gamma)
T = eltype(x)
p = sortperm(abs.(x)[:], rev=true)
y = similar(x)
y[p[1:f.r]] .= x[p[1:f.r]]
y[p[f.r+1:end]] .= T(0)
y[p[begin:begin+f.r-1]] .= x[p[begin:begin+f.r-1]]
y[p[begin+f.r:end]] .= T(0)
return y, real(T)(0)
end
10 changes: 5 additions & 5 deletions src/functions/indPSD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ function (f::IndPSD)(x::AbstractVector{Float64})
f.scaling && scale_diagonal!(y, sqrt(2))

Z = dspev!(:N, :L, y)
for i in 1:length(Z)
for i in eachindex(Z)
# Do we allow for some tolerance here?
if Z[i] <= -1e-14
return +Inf
Expand All @@ -118,7 +118,7 @@ function prox!(y::AbstractVector{Float64}, f::IndPSD, x::AbstractVector{Float64}
# Now let M = Z*diagm(W)*Z'
M = M*Z'
n = length(W)
k = 1
k = firstindex(y)
# Store lower diagonal of M in y
for j in 1:n, i in j:n
y[k] = M[i,j]
Expand All @@ -135,8 +135,8 @@ function prox_naive(f::IndPSD, x::AbstractVector{Float64}, gamma)
# Formula for size of matrix
n = Int(sqrt(1/4+2*length(x))-1/2)
X = Matrix{Float64}(undef, n, n)
k = 1
# Store y in M
k = firstindex(x)
# Store x in X
for j = 1:n, i = j:n
# Lower half
X[i,j] = x[k]
Expand Down Expand Up @@ -164,7 +164,7 @@ function prox_naive(f::IndPSD, x::AbstractVector{Float64}, gamma)
end

y = similar(x)
k = 1
k = firstindex(y)
# Store Lower half of X in y
for j = 1:n, i = j:n
y[k] = X[i,j]
Expand Down
20 changes: 10 additions & 10 deletions src/functions/normL21.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,17 @@ function (f::NormL21)(X)
nslice = R(0)
n21X = R(0)
if f.dim == 1
for j = 1:size(X, 2)
for j in axes(X, 2)
nslice = R(0)
for i = 1:size(X, 1)
for i in axes(X, 1)
nslice += abs(X[i, j])^2
end
n21X += sqrt(nslice)
end
elseif f.dim == 2
for i = 1:size(X, 1)
for i in axes(X, 1)
nslice = R(0)
for j = 1:size(X, 2)
for j in axes(X, 2)
nslice += abs(X[i, j])^2
end
n21X += sqrt(nslice)
Expand All @@ -58,29 +58,29 @@ function prox!(Y, f::NormL21, X, gamma)
nslice = R(0)
n21X = R(0)
if f.dim == 1
for j = 1:size(X, 2)
for j in axes(X, 2)
nslice = R(0)
for i = 1:size(X, 1)
for i in axes(X, 1)
nslice += abs(X[i, j])^2
end
nslice = sqrt(nslice)
scal = 1 - gl / nslice
scal = scal <= 0 ? R(0) : scal
for i = 1:size(X, 1)
for i in axes(X, 1)
Y[i, j] = scal * X[i, j]
end
n21X += scal * nslice
end
elseif f.dim == 2
for i = 1:size(X, 1)
for i in axes(X, 1)
nslice = R(0)
for j = 1:size(X, 2)
for j in axes(X, 2)
nslice += abs(X[i, j])^2
end
nslice = sqrt(nslice)
scal = 1-gl/nslice
scal = scal <= 0 ? R(0) : scal
for j = 1:size(X, 2)
for j in axes(X, 2)
Y[i, j] = scal * X[i, j]
end
n21X += scal * nslice
Expand Down
4 changes: 2 additions & 2 deletions test/test_calculus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,11 @@ stuff = [
)
]

@testset "$i" for i = 1:length(stuff)
@testset "$i" for i in eachindex(stuff)
f = stuff[i]["funcs"][1]
g = stuff[i]["funcs"][2]

for j = 1:length(stuff[i]["args"])
for j in eachindex(stuff[i]["args"])
x = stuff[i]["args"][j]
gamma = stuff[i]["gammas"][j]

Expand Down
2 changes: 1 addition & 1 deletion test/test_gradients.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ stuff = [
),
]

for i = 1:length(stuff)
for i in eachindex(stuff)

f = stuff[i]["f"]
x = stuff[i]["x"]
Expand Down
6 changes: 3 additions & 3 deletions test/test_graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,17 @@ stuff = [
),
]

for i = 1:length(stuff)
for i in eachindex(stuff)
constr = stuff[i]["constr"]

if haskey(stuff[i], "wrong")
for j = 1:length(stuff[i]["wrong"])
for j in eachindex(stuff[i]["wrong"])
wrong = stuff[i]["wrong"][j]
@test_throws ErrorException constr(wrong...)
end
end

for j = 1:length(stuff[i]["params"])
for j in eachindex(stuff[i]["params"])
params = stuff[i]["params"][j]
x = stuff[i]["args"][j]
f = constr(params...)
Expand Down
2 changes: 1 addition & 1 deletion test/test_results.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ stuff = [
)
]

@testset "$(i)" for i = 1:length(stuff)
@testset "$(i)" for i in eachindex(stuff)

f = stuff[i]["f"]
x = stuff[i]["x"]
Expand Down

2 comments on commit 5b2845a

@lostella
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/61070

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.15.3 -m "<description of version>" 5b2845a2d34f61a4e002ee28801ae1de5da8639f
git push origin v0.15.3

Please sign in to comment.