[Feature request] Integration with optax #51

Open
carlosgmartin opened this issue Apr 20, 2023 · 5 comments

@carlosgmartin

optax is the most popular JAX library for optimizers. Feature request: let users pass an optax.GradientTransformation to strategy constructors, rather than a combination of opt_name, lrate_init, etc. This has a few advantages, which I expand on in the comments below.
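
To make the request concrete, a minimal sketch of the object users would pass in: any optax optimizer (or chain of transforms) is an optax.GradientTransformation, i.e. just an init and an update function, which is all a strategy needs in order to maintain its search mean:

import jax.numpy as jnp
import optax

# Any optax optimizer is a GradientTransformation: a pair of pure functions
# (init, update) that a strategy could call on its search-distribution mean.
opt = optax.adam(1e-2)

mean = jnp.zeros(4)                 # e.g. the flat mean vector of a strategy
opt_state = opt.init(mean)

grad = jnp.ones(4)                  # e.g. an ES gradient estimate
updates, opt_state = opt.update(grad, opt_state, mean)
mean = optax.apply_updates(mean, updates)

A hypothetical constructor along the lines of OpenES(popsize=..., optimizer=optax.adam(1e-2)) would then replace the opt_name/lrate_init arguments.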

Thank you for creating this library. Look forward to hearing your thoughts.

@RobertTLange
Owner

Hi @carlosgmartin -- thank you for raising this and please excuse the late response! I have actually been thinking about this for a while now. For now, we have opted for the arguably awkward re-implementation of the most common optimizers. There were a couple of reasons for this:

  • The opt_params can directly be stored in the es_params dataclass, which enables easy mapping across different learning rates/decay parameters. But to be entirely honest -- this doesn't happen so frequently.
  • All ES strategies operate on flat parameter vectors rather than pytrees, which simplifies some aspects of the optimizer (see the sketch below).
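
To illustrate the flat-vector convention: a flat array is itself a valid pytree, so optax transforms could be applied to it directly, and jax.flatten_util.ravel_pytree bridges the two representations. A minimal sketch:

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}

# The ES internals would see only the flat vector; unravel recovers the pytree.
flat, unravel = ravel_pytree(params)
print(flat.shape)                  # (9,)
print(unravel(flat)["w"].shape)    # (2, 3)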

In general, I am very open to switching to optax in the long run, especially since I am also personally interested in trying out some of the learned gradient-based optimizers in the context of ES (e.g. LION or VeLO), all of which already have optax support. Furthermore, it may make population-parallelism (similar to data-parallelism in multi-device settings) a lot smoother. I have started playing around with some of these ideas for OpenAI-ES here.

Let me know if you would be interested in supporting me in this endeavor. My bucket list right now is pretty large -- so at the moment I can't guarantee that this will happen within the next week(s). Cheers and again thank you, Rob

@carlosgmartin
Author

carlosgmartin commented May 17, 2023

@RobertTLange Perhaps a good first step could be to reimplement https://github.com/RobertTLange/evosax/blob/main/evosax/core/optimizer.py internally in terms of optax (perhaps via an OptaxWrapper class), and gradually change the API "outwards" from there. What do you think?
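
A rough sketch of the shape such an OptaxWrapper could take; the method names below are hypothetical rather than taken from evosax/core/optimizer.py. The point is only that the internal optimizer interface can be backed by any optax.GradientTransformation:

import jax.numpy as jnp
import optax


class OptaxWrapper:
    # Hypothetical adapter: evosax-style flat-vector updates, backed by optax.

    def __init__(self, optimizer: optax.GradientTransformation):
        self.optimizer = optimizer

    def init(self, mean: jnp.ndarray):
        return self.optimizer.init(mean)

    def step(self, mean: jnp.ndarray, grad: jnp.ndarray, opt_state):
        updates, opt_state = self.optimizer.update(grad, opt_state, mean)
        return optax.apply_updates(mean, updates), opt_state


# Usage with any optax optimizer or chain of transforms:
opt = OptaxWrapper(optax.adam(1e-2))
mean = jnp.zeros(4)
opt_state = opt.init(mean)
mean, opt_state = opt.step(mean, jnp.ones(4), opt_state)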

@RobertTLange
Owner

RobertTLange commented May 19, 2023

That indeed sounds like a great start ;)

The question is how we ultimately want to pass the optax optimizer to the finite-difference-style strategies (OpenAI-ES, PGPE, ARS, ASEBO, etc.). Should this be a string, as is done right now? Or should we directly pass an optimizer gradient-transform function at strategy initialization? Maybe we can support both for a start?
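
A hedged sketch of how both call styles could coexist during a transition: a hypothetical helper (not part of evosax) that normalizes either a legacy opt_name string or a ready-made optax optimizer into a GradientTransformation:

from typing import Union

import optax


def resolve_optimizer(opt: Union[str, optax.GradientTransformation],
                      learning_rate: float = 1e-2) -> optax.GradientTransformation:
    # Hypothetical helper: accept either an optimizer name (current API)
    # or an already-constructed optax optimizer (proposed API).
    if isinstance(opt, str):
        constructors = {"sgd": optax.sgd, "adam": optax.adam, "rmsprop": optax.rmsprop}
        return constructors[opt](learning_rate)
    return opt  # already a GradientTransformation


# Both call styles yield the same kind of object:
resolve_optimizer("adam", 1e-2)
resolve_optimizer(optax.adam(1e-2))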

@carlosgmartin
Author

carlosgmartin commented May 19, 2023

It seems to me that, at least in the long run, the optax optimizer should be passed in directly. I see a few advantages to this approach:

  • It saves evosax from having to "keep up" with optax's expanding list of optimizers.
  • It simplifies evosax's code.
  • It makes evosax more flexible, since users can pass in arbitrary compositions of optax gradient transforms (see the sketch below).
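
As an example of the last point, the kind of composition that would come for free once a GradientTransformation can be passed in directly (the particular chain is only illustrative):

import optax

# Gradient clipping + Adam scaling + a cosine learning-rate schedule,
# composed by the user and handed to the strategy as a single optimizer.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.scale_by_adam(),
    optax.scale_by_schedule(optax.cosine_decay_schedule(init_value=1e-2, decay_steps=10_000)),
    optax.scale(-1.0),  # descend: apply updates in the negative gradient direction
)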

What do you think?

@carlosgmartin
Author

carlosgmartin commented Jul 20, 2023

@RobertTLange Here's an example of the simplified interface and implementation I have in mind:

import argparse
import sys

import evosax
import jax
import optax
from jax import lax, numpy as jnp, random
from jax.flatten_util import ravel_pytree


# Minimal OpenAI-ES that delegates its parameter update to an arbitrary
# optax optimizer passed in at construction time.
class OpenAIES:
    def __init__(self, loss_fn, scale, batch_size, optimizer):
        assert batch_size % 2 == 0, "batch_size must be even"
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.pairs = batch_size // 2
        self.scale = scale

    def init(self, params, key):
        return params, self.optimizer.init(params)

    def grads(self, params, key):
        # Antithetic ES gradient estimate: evaluate the loss at the mirrored
        # perturbations x +/- scale * z and regress the differences onto z.
        x, unravel = ravel_pytree(params)
        keys = random.split(key, 1 + self.pairs)
        z = random.normal(keys[0], [self.pairs, x.size])
        u = z * self.scale
        yp = jax.vmap(self.loss_fn)(jax.vmap(unravel)(x + u), keys[1:])
        ym = jax.vmap(self.loss_fn)(jax.vmap(unravel)(x - u), keys[1:])
        g = (yp - ym) @ z / (len(z) * 2 * self.scale)
        grads = unravel(g)
        return grads

    def step(self, state, key):
        params, opt_state = state
        grads = self.grads(params, key)
        updates, opt_state = self.optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return (params, opt_state), None

    def params(self, state):
        params, opt_state = state
        return params


# Adapter that gives evosax's built-in OpenES the same init/step/params
# interface, for a side-by-side comparison.
class EvosaxWrapper:
    def __init__(self, loss_fn, strategy):
        self.loss_fn = loss_fn
        self.strategy = strategy

    def init(self, params, key):
        return self.strategy.initialize(key, init_mean=params)

    def step(self, state, key):
        keys = jax.random.split(key, 1 + self.strategy.popsize)
        x, state = self.strategy.ask(keys[0], state)
        # jnp.concatenate([keys[1:], keys[1:]])
        y = jax.vmap(self.loss_fn)(x, keys[1:])
        n_state = self.strategy.tell(x, y.astype(float), state)
        return n_state, None

    def params(self, state):
        return self.strategy.param_reshaper.reshape_single(state.mean)


def parse_args(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--scale", type=float, default=1e-1)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=1e-2)
    parser.add_argument("--steps", type=int, default=10**4)
    return parser.parse_args(argv)


def main(argv):
    args = parse_args(argv)

    def loss_fn(params, key):
        return params["a"] @ params["a"] + random.normal(key) * 0.1

    params = {"a": jnp.ones(4)}
    key = random.PRNGKey(args.seed)

    for strategy in [
        EvosaxWrapper(
            loss_fn,
            evosax.OpenES(
                popsize=args.batch_size,
                pholder_params=params,
                opt_name="adam",
                lrate_init=args.lr,
                lrate_limit=0,
                sigma_init=args.scale,
                sigma_limit=0,
            ),
        ),
        OpenAIES(loss_fn, args.scale, args.batch_size, optax.adam(args.lr)),
    ]:
        keys = random.split(key, 1 + args.steps)
        state = strategy.init(params, keys[0])
        state, _ = lax.scan(strategy.step, state, keys[1:])
        print(strategy.params(state)["a"])


if __name__ == "__main__":
    main(sys.argv[1:])

Output:

ParameterReshaper: 4 parameters detected for optimization.
[ 0.01049289 -0.00886521 -0.00689528  0.0649689 ]
[-1.65249183e-08  9.41156397e-09  2.47366607e-08  1.30447395e-08]
