First step of GradientDescent optimizer is a no-op #82

Open

eringrant opened this issue Sep 5, 2024 · 1 comment

eringrant commented Sep 5, 2024

It seems like the first call to step of the GradientDescent optimizer doesn't perform the step operation. I haven't checked whether this occurs for other optimizers, or done any other digging, but I can do so if this is not expected behavior and the cause is not immediately obvious. Here is an MWE:

import equinox as eqx
import jax
import jax.numpy as jnp
import optimistix as optx
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_optax", action="store_true")
args = parser.parse_args()

if args.use_optax:
  import optax
  optimizer = optx.OptaxMinimiser(optax.sgd(1e-1), rtol=1e-4, atol=1e-4)
else:
  optimizer = optx.GradientDescent(learning_rate=1e-1, rtol=1e-4, atol=1e-4)

# Regression problem: fit a (K, N) matrix so that w @ x.T matches y = w_star @ x.T.
N = K = 8

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
w_star = jax.random.normal(k1, (K, N))
w_hat = jax.random.normal(k2, (K, N))

x = jnp.linspace(0, 1, N)[None, ...]
y = jnp.dot(w_star, x.T)


def loss(w, _):
  return jnp.mean((jnp.dot(w, x.T) - y) ** 2), None


# Static metadata required by the solver's manual init/step API.
options = None
f_struct = jax.ShapeDtypeStruct((), jnp.float32)
aux_struct = None
tags = frozenset()

# Bind the static arguments; each call then only passes y (and, for step, state).
init = eqx.Partial(
  optimizer.init,
  args=None,
  fn=loss,
  options=options,
  f_struct=f_struct,
  aux_struct=aux_struct,
  tags=tags,
)
step = eqx.Partial(
  optimizer.step,
  args=None,
  fn=loss,
  options=options,
  tags=tags,
)

state = init(y=w_hat)
initial_loss = loss(w_hat, None)[0]
print(f"t = 0 | loss = {initial_loss}.")

w_hat, state, _ = step(y=w_hat, state=state)
one_step_loss = loss(w_hat, None)[0]
print(f"t = 1 | loss = {one_step_loss}.")

w_hat, state, _ = step(y=w_hat, state=state)
two_step_loss = loss(w_hat, None)[0]
print(f"t = 2 | loss = {two_step_loss}.")

if initial_loss == one_step_loss:
  raise ValueError("Loss did not decrease after one step of optimization.")

Running with GradientDescent gives:

$ python test.py
t = 0 | loss = 2.189293384552002.
t = 1 | loss = 2.189293384552002.
t = 2 | loss = 1.8877067565917969.
Traceback (most recent call last):
  File ".../test.py", line 68, in <module>
    raise ValueError("Loss did not decrease after one step of optimization.")
ValueError: Loss did not decrease after one step of optimization.

cf. OptaxMinimiser(optax.sgd(...), ...):

$ python test.py --use_optax
t = 0 | loss = 2.189293384552002.
t = 1 | loss = 1.8877067565917969.
t = 2 | loss = 1.6276657581329346.
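
Consistent with the two transcripts above, the GradientDescent losses look like the optax losses delayed by exactly one call. A quick check of that hypothesis (a hypothetical snippet, not part of the original script: run_losses is a made-up helper, it reuses loss/options/f_struct/aux_struct/tags/w_hat from the MWE, and it assumes optax is imported unconditionally):

def run_losses(optimizer, w, n_steps):
  # Drive the solver manually and record the loss after each step.
  state = optimizer.init(
    fn=loss, y=w, args=None, options=options,
    f_struct=f_struct, aux_struct=aux_struct, tags=tags,
  )
  losses = []
  for _ in range(n_steps):
    w, state, _ = optimizer.step(
      fn=loss, y=w, args=None, options=options, state=state, tags=tags
    )
    losses.append(float(loss(w, None)[0]))
  return losses

gd = optx.GradientDescent(learning_rate=1e-1, rtol=1e-4, atol=1e-4)
sgd = optx.OptaxMinimiser(optax.sgd(1e-1), rtol=1e-4, atol=1e-4)
print(run_losses(gd, w_hat, 3)[1:])  # hypothesis: equals the optax losses below
print(run_losses(sgd, w_hat, 2))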
patrick-kidger (Owner) commented

This is expected... but admittedly maybe not great design.

The relevant code is here:

def step(
    self,
    fn: Fn[Y, Scalar, Aux],
    y: Y,
    args: PyTree,
    options: dict[str, Any],
    state: _GradientDescentState,
    tags: frozenset[object],
) -> tuple[Y, _GradientDescentState, Aux]:
    f_eval, lin_fn, aux_eval = jax.linearize(
        lambda _y: fn(_y, args), state.y_eval, has_aux=True
    )
    step_size, accept, search_result, search_state = self.search.step(
        state.first_step,
        y,
        state.y_eval,
        state.f_info,
        FunctionInfo.Eval(f_eval),
        state.search_state,
    )

    def accepted(descent_state):
        (grad,) = lin_to_grad(lin_fn, y)
        f_eval_info = FunctionInfo.EvalGrad(f_eval, grad)
        descent_state = self.descent.query(state.y_eval, f_eval_info, descent_state)
        y_diff = (state.y_eval**ω - y**ω).ω
        f_diff = (f_eval**ω - state.f_info.f**ω).ω
        terminate = cauchy_termination(
            self.rtol, self.atol, self.norm, state.y_eval, y_diff, f_eval, f_diff
        )
        terminate = jnp.where(
            state.first_step, jnp.array(False), terminate
        )  # Skip termination on first step
        return state.y_eval, f_eval_info, aux_eval, descent_state, terminate

    def rejected(descent_state):
        return y, state.f_info, state.aux, descent_state, jnp.array(False)

    y, f_info, aux, descent_state, terminate = filter_cond(
        accept, accepted, rejected, state.descent_state
    )

    y_descent, descent_result = self.descent.step(step_size, descent_state)
    y_eval = (y**ω + y_descent**ω).ω
    result = RESULTS.where(
        search_result == RESULTS.successful, descent_result, search_result
    )

    state = _GradientDescentState(
        first_step=jnp.array(False),
        y_eval=y_eval,
        search_state=search_state,
        f_info=f_info,
        aux=aux,
        descent_state=descent_state,
        terminate=terminate,
        result=result,
    )
    return y, state, aux

The way this works is that we actually implement general gradient methods, which typically start by picking a descent direction and then performing a line search in that direction. Once the line search has found an acceptable point at which to stop, that location is used as the start of the next line search.

In the case of GradientDescent, the line search is a single step whose size corresponds to the learning rate, and its result is always treated as acceptable. This means that the 'accepted' point is the start of the line search -- which is the previous iterate.
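
Concretely, the bookkeeping looks something like this (a minimal standalone sketch, not the actual optimistix internals; sketch_step and everything around it is made up for illustration):

import jax
import jax.numpy as jnp

def sketch_step(grad_fn, y_eval, lr):
  # On acceptance, the point we just *evaluated* becomes the reported
  # iterate...
  y = y_eval
  # ...while the descent step only produces the *next* point to evaluate.
  y_next_eval = y_eval - lr * grad_fn(y_eval)
  return y, y_next_eval

grad_fn = jax.grad(lambda y: jnp.sum(y**2))
y0 = jnp.array([1.0, 2.0])
y_eval = y0  # init sets the first evaluation point to the initial iterate
for t in range(1, 4):
  y, y_eval = sketch_step(grad_fn, y_eval, lr=0.1)
  print(f"t = {t} | y = {y}")  # t = 1 prints y0 unchanged

So each call reports the point accepted from the previous call's evaluation, and the visible iterates lag the gradient updates by exactly one call -- matching the transcripts in the MWE above.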

Off the top of my head I'm not sure how we'd change this. We might be able to tweak the logic in the above block of code to remove this off-by-one, as sketched below. (I'm open to suggestions on this one.)
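
In terms of the sketch above, one naive direction (untested, and it ignores line-search rejection and the termination bookkeeping, so it is at most a starting point) would be to report the freshly descended point instead of the previously evaluated one:

def sketch_step_eager(grad_fn, y_eval, lr):
  # Report the descended point immediately rather than one call later.
  y_next = y_eval - lr * grad_fn(y_eval)
  return y_next, y_next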
