
Rework regularized layers #73

Closed · wants to merge 1 commit

Conversation

@BatyLeo (Member) commented Jun 29, 2023

  • Regularized layers are now unified under the `Regularized` struct rather than the `IsRegularized` trait (partially addresses "Get rid of SimpleTraits? #68"). Every regularized layer is now a particular instance of `Regularized` (a sketch of the resulting API follows below).
  • Specific constructors for `SparseArgmax`, `SoftArgmax`, and `RegularizedFrankWolfe`
  • `Regularized` can now also be used with a custom optimizer (addresses "Other solvers than FW for RegularizedGeneric #62")

TODO:

  • cleanup and docstrings
  • test Regularized with a custom optimizer

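A minimal sketch of the resulting API, assuming the constructor takes the regularizer first and the layer is callable as `layer(θ)` (both assumptions, not taken from the diff):

```julia
using InferOpt

# Specialized constructors, each a particular instance of `Regularized`
soft = SoftArgmax()      # regularizer: negative entropy
sparse = SparseArgmax()  # regularizer: half squared Euclidean norm

# Generic usage with a custom regularizer Ω and a custom optimizer
# solving θ ⟼ argmax_y θᵀy - Ω(y)
Ω(y) = 0.5 * sum(abs2, y)
custom_optimizer(θ; kwargs...) = θ  # closed form in the unconstrained case
layer = Regularized(Ω, custom_optimizer)
y = layer(randn(5))
```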
@codecov-commenter commented Jun 29, 2023

Codecov Report

Patch coverage: 77.77% and project coverage change: -0.25 ⚠️

Comparison is base (b9f84b9) 80.57% compared to head (3fc2bc1) 80.33%.


Additional details and impacted files
@@            Coverage Diff             @@
##             main      #73      +/-   ##
==========================================
- Coverage   80.57%   80.33%   -0.25%     
==========================================
  Files          19       20       +1     
  Lines         345      356      +11     
==========================================
+ Hits          278      286       +8     
- Misses         67       70       +3     
| Impacted Files                           | Coverage Δ             |
|------------------------------------------|------------------------|
| src/InferOpt.jl                          | 100.00% <ø> (ø)        |
| src/regularized/frank_wolfe_optimizer.jl | 25.00% <25.00%> (ø)    |
| src/regularized/regularized.jl           | 78.57% <78.57%> (ø)    |
| ext/InferOptFrankWolfeExt.jl             | 100.00% <100.00%> (ø)  |
| src/fenchel_young/fenchel_young.jl       | 88.00% <100.00%> (ø)   |
| src/regularized/soft_argmax.jl           | 100.00% <100.00%> (ø)  |
| src/regularized/sparse_argmax.jl         | 100.00% <100.00%> (ø)  |


@BatyLeo BatyLeo added the enhancement New feature or request label Jun 29, 2023
@BatyLeo BatyLeo linked an issue Jun 29, 2023 that may be closed by this pull request
@@ -32,30 +32,14 @@ Some values you can tune:

See the documentation of FrankWolfe.jl for details.
"""
-struct RegularizedGeneric{M,RF,RG,FWK}
-    maximizer::M
+struct FrankWolfeOptimizer{M,RF,RG,FWK}
Collaborator:
I would rather call this FrankWolfeConcaveMaximizer
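For reference, a plausible reading of the full renamed struct, with the remaining field names carried over from the old `RegularizedGeneric` docstring (an assumption; only the first lines appear in the hunk above):

```julia
struct FrankWolfeOptimizer{M,RF,RG,FWK}
    linear_maximizer::M      # θ -> argmax_{x ∈ C} θᵀx, implicitly defines the polytope C
    Ω::RF                    # regularizer Ω(y)
    Ω_grad::RG               # gradient oracle ∇Ω(y)
    frank_wolfe_kwargs::FWK  # keyword arguments forwarded to FrankWolfe.jl
end
```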

"""
struct Regularized{O,R}
Ω::R
optimizer::O
Collaborator:
I would rather call this concave_maximizer to differentiate from (linear_)maximizer used elsewhere

TODO
"""
function RegularizedFrankWolfe(linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs=NamedTuple())
# TODO: add a warning if DifferentiableFrankWolfe is not imported?
Collaborator:
Good idea
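One way to implement the suggested warning, sketched under the assumption that the conditional dependency is a Julia ≥ 1.9 package extension named `InferOptFrankWolfeExt` (the extension file shows up in the Codecov report above):

```julia
function RegularizedFrankWolfe(linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs=NamedTuple())
    # Warn if the extension providing the differentiable Frank-Wolfe solver is not loaded
    if isnothing(Base.get_extension(@__MODULE__, :InferOptFrankWolfeExt))
        @warn "Run `import DifferentiableFrankWolfe` before using `RegularizedFrankWolfe`."
    end
    optimizer = FrankWolfeOptimizer(linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs)
    return Regularized(Ω, optimizer)
end
```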

@@ -9,7 +9,7 @@ Relies on the Frank-Wolfe algorithm to minimize a concave objective on a polytop
Since this is a conditional dependency, you need to run `import DifferentiableFrankWolfe` before using `RegularizedGeneric`.

# Fields
-- `maximizer::M`: linear maximization oracle `θ -> argmax_{x ∈ C} θᵀx`, implicitly defines the polytope `C`
+- `linear_maximizer::M`: linear maximization oracle `θ -> argmax_{x ∈ C} θᵀx`, implicitly defines the polytope `C`
Collaborator:
Maybe we should use linear_maximizer throughout InferOpt?

"""
optimizer: θ ⟼ argmax θᵀy - Ω(y)
"""
struct Regularized{O,R}
Collaborator:
Do we also need the linear maximizer as a field, for when the layer is called outside of training?
It would make sense to me to modify the behavior of Perturbed as well, so that the standard forward pass just calls the naked linear maximizer.
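For context, a sketch of the forward pass the docstring implies (assumed; the suggestion above would add a `linear_maximizer` field and call it instead outside of training):

```julia
function (regularized::Regularized)(θ::AbstractVector; kwargs...)
    # During training, solve the regularized problem argmax_y θᵀy - Ω(y)
    return regularized.optimizer(θ; kwargs...)
end
```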

@@ -10,8 +10,12 @@ function soft_argmax(z::AbstractVector; kwargs...)
return s
end

-@traitimpl IsRegularized{typeof(soft_argmax)}
+# @traitimpl IsRegularized{typeof(soft_argmax)}
Collaborator:
In the trash
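The specific constructor replacing this trait might look as follows (a sketch: negative entropy is the standard regularizer for `soft_argmax`, but the argument order of `Regularized` is assumed):

```julia
SoftArgmax() = Regularized(y -> sum(yᵢ * log(yᵢ) for yᵢ in y), soft_argmax)
```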

@@ -10,11 +10,15 @@ function sparse_argmax(z::AbstractVector; kwargs...)
return p
end

-@traitimpl IsRegularized{typeof(sparse_argmax)}
+# @traitimpl IsRegularized{typeof(sparse_argmax)}
Collaborator:
In the trash
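Likewise for `sparse_argmax`, whose standard regularizer is the half squared Euclidean norm (same assumptions as the `soft_argmax` sketch above):

```julia
SparseArgmax() = Regularized(y -> 0.5 * sum(abs2, y), sparse_argmax)
```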

@gdalle (Collaborator) commented Jun 29, 2023

> test Regularized with a custom optimizer

What do you have in mind? I think we can use a basic QP solver from JuMP or write our own with FISTA
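For instance, a self-contained projected gradient ascent over the probability simplex could serve as the test's custom optimizer (a sketch of that idea, not code from the PR; `simplex_projection` implements the standard sorting-based projection):

```julia
# Euclidean projection onto the probability simplex (standard sorting algorithm)
function simplex_projection(z::AbstractVector)
    n = length(z)
    u = sort(z; rev=true)
    k = maximum(i for i in 1:n if u[i] + (1 - sum(u[1:i])) / i > 0)
    τ = (1 - sum(u[1:k])) / k
    return max.(z .+ τ, 0)
end

# Projected gradient ascent on y ↦ θᵀy - Ω(y) over the simplex;
# the default Ω_grad = identity corresponds to Ω(y) = ½‖y‖²
function custom_optimizer(θ; Ω_grad=identity, steps=100, η=0.1)
    y = fill(1 / length(θ), length(θ))
    for _ in 1:steps
        y = simplex_projection(y + η * (θ - Ω_grad(y)))
    end
    return y
end
```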

@gdalle gdalle marked this pull request as draft June 30, 2023 11:34
@gdalle gdalle closed this Jun 30, 2023
Labels: enhancement (New feature or request)
Projects: none yet
Development: successfully merging this pull request may close "Other solvers than FW for RegularizedGeneric"
3 participants