Skip to content

Commit

Permalink
[optim] Fix: wrong ASGD implementation (#126375)
Browse files Browse the repository at this point in the history
This PR is based on #125440, additionally merging the latest main branch and fixing the lint failures from #126361.

Pull Request resolved: #126375
Approved by: https://github.com/janeyx99
  • Loading branch information
david20571015 authored and pytorchmergebot committed May 17, 2024
1 parent 078e530 commit 7e166e8
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 25 deletions.
10 changes: 9 additions & 1 deletion test/test_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,8 +604,16 @@ def _compare_between(self, inputs, models, optimizers, assert_eq_kwargs=None, as
for input, model, optimizer in zip(inputs, models, optimizers):
optimizer.zero_grad()

if i == 3:
# Freeze a layer to test whether the step of this layer in 'fused' or 'foreach'
# is the same as the step in 'forloop'.
model[2].requires_grad_(False)
if i == 5:
# Unfreeze the layer after 2 iters.
model[2].requires_grad_(True)

# Test that step behaves as expected (a no-op) when grads are set to None
if i != 3:
if i != 2:
output = model(input)
loss = output.sum()
loss.backward()
Expand Down
10 changes: 10 additions & 0 deletions torch/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
corresponding_real_dtype,
elementwise_dtypes,
ELEMENTWISE_TYPE_PROMOTION_KIND,
FloatLike,
IntLike,
make_contiguous_strides_for,
Number,
Expand Down Expand Up @@ -3286,6 +3287,15 @@ def _meta_foreach_inplace(*args, _scalar_op=None, **kwargs):
return


# Meta registration for the `aten._foreach_pow_.Scalar` overload.
# It only validates its arguments: `exponent` must be an instance of
# `FloatLike`; the check and its error message are delegated to `torch._check`.
# Nothing is returned (bare `return`) -- presumably because the op is in-place,
# as the trailing underscore in the op name suggests; TODO confirm.
@register_meta([aten._foreach_pow_.Scalar])
def meta__foreach_pow__scalar(self, exponent):
torch._check(
isinstance(exponent, FloatLike),
# The message is supplied as a callable rather than a pre-built string.
lambda: f"exponent must be a float but got {type(exponent)}",
)
return


@register_meta([aten._foreach_pow.ScalarAndTensor])
def meta__foreach_pow_scalar_and_tensor(self, exponent):
# Only foreach_pow has a ScalarAndTensor method and needs special
Expand Down
37 changes: 13 additions & 24 deletions torch/optim/asgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,6 @@
__all__ = ["ASGD", "asgd"]


# Helper: coerce `x` to a tensor.
# A non-tensor `x` is wrapped with `torch.tensor(x, device=device)`.
# An existing `torch.Tensor` is returned as-is (same object); `device` is
# not applied in that case.
# NOTE(review): this page renders a diff without +/- markers; this helper
# appears to be the one replaced by the `torch.as_tensor(...)` calls in the
# later hunks -- confirm against the actual commit.
def _to_tensor(x, device=None):
if not isinstance(x, torch.Tensor):
return torch.tensor(x, device=device)

return x


class ASGD(Optimizer):
def __init__(
self,
Expand Down Expand Up @@ -264,9 +257,9 @@ def _single_tensor_asgd(
mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t)))
else:
step = _get_value(step_t)
new_eta = _to_tensor(lr / ((1 + lambd * lr * step) ** alpha))
new_eta = torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha))
eta.copy_(new_eta)
new_mu = _to_tensor(1 / max(1, step - t0))
new_mu = torch.as_tensor(1 / max(1, step - t0))
mu.copy_(new_mu)


Expand Down Expand Up @@ -381,27 +374,23 @@ def _multi_tensor_asgd(
torch._foreach_copy_(grouped_mus, new_mus)
del new_mus

# update eta = lr / (1 + lambd * lr * step^alpha)
new_etas = torch._foreach_pow(grouped_state_steps, alpha)
torch._foreach_mul_(new_etas, lambd)
# update eta = lr / ((1 + lambd * lr * step)^alpha)
new_etas = torch._foreach_mul(grouped_state_steps, lambd)
torch._foreach_mul_(new_etas, lr)
torch._foreach_add_(new_etas, 1)
torch._foreach_pow_(new_etas, alpha)
torch._foreach_reciprocal_(new_etas)
torch._foreach_mul_(new_etas, lr)
torch._foreach_copy_(grouped_etas, new_etas)
else:
step = grouped_state_steps[0].item()
new_etas = []
new_mus = []

for i in range(len(grouped_mus)):
new_eta = _to_tensor(
lr / (1 + lambd * lr * step**alpha), device=device
)
new_etas.append(new_eta)
new_mu = _to_tensor(1 / max(1, step - t0), device=device)
new_mus.append(new_mu)

new_etas = [
torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha), device=device)
for step in grouped_state_steps
]
new_mus = [
torch.as_tensor(1 / max(1, _get_value(step) - t0), device=device)
for step in grouped_state_steps
]
torch._foreach_copy_(grouped_etas, new_etas)
torch._foreach_copy_(grouped_mus, new_mus)

Expand Down
8 changes: 8 additions & 0 deletions torch/testing/_internal/common_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,7 @@ def optim_inputs_func_asgd(device, dtype=None):
]
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lambd": 0.1}, desc="non-default lambd"),
OptimizerInput(params=None, kwargs={"lr": 0.02}, desc="non-default lr"),
OptimizerInput(params=None, kwargs={"t0": 100}, desc="t0"),
OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
Expand Down Expand Up @@ -1450,6 +1451,13 @@ def _get_optim_inputs_including_global_cliquey_kwargs(
"TestOptimRenewed",
"test_defaults_changed_to_foreach",
),
DecorateInfo(
unittest.skip(
"ASGD internally changes the weights even with zero grad"
),
"TestOptimRenewed",
"test_step_is_noop_for_zero_grads",
),
),
),
OptimizerInfo(
Expand Down

0 comments on commit 7e166e8

Please sign in to comment.