[Feature request] Integration with optax #51
Hi @carlosgmartin -- thank you for raising this, and please excuse the late response! I have actually been thinking about this for a while. For now, we have opted for an arguably awkward re-implementation of the most common optimizers. There were a couple of reasons for this:
In general, I am very open to switching to optax in the long run, especially since I am personally interested in trying out some of the learned gradient-based optimizers (e.g. LION or VeLO) in the context of ES, and these already have optax support. Furthermore, it may make population parallelism (similar to data parallelism in multi-device settings) a lot smoother. I have started playing around with some of these ideas for OpenAI-ES here. Let me know if you would be interested in supporting me in this endeavor. My to-do list is pretty long right now, so I can't guarantee that this will happen within the next week(s). Cheers, and thank you again, Rob
@RobertTLange Perhaps a good first step could be to reimplement https://github.com/RobertTLange/evosax/blob/main/evosax/core/optimizer.py internally in terms of optax (perhaps via an `OptaxWrapper` class), and gradually change the API "outwards" from there. What do you think?
That indeed sounds like a great start ;) The question is how we ultimately want to pass the optax optimizer to the finite-difference-style strategies (OpenAI-ES, PGPE, ARS, ASEBO, etc.). Should it be a string, as it is right now, or should we directly pass an optax gradient transformation at strategy initialization? Maybe we can support both to start with?
It seems to me that, at least in the long run, the optax optimizer should be passed in directly. I see a few advantages to this approach:
What do you think?
@RobertTLange Here's an example of the simplified interface and implementation I have in mind:

```python
import argparse
import sys

import evosax
import jax
import optax
from jax import lax, numpy as jnp, random
from jax.flatten_util import ravel_pytree


class OpenAIES:
    """OpenAI-ES with antithetic sampling; the update rule is any optax optimizer."""

    def __init__(self, loss_fn, scale, batch_size, optimizer):
        assert batch_size % 2 == 0, "batch_size must be even"
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.pairs = batch_size // 2
        self.scale = scale

    def init(self, params, key):
        return params, self.optimizer.init(params)

    def grads(self, params, key):
        x, unravel = ravel_pytree(params)
        keys = random.split(key, 1 + self.pairs)
        z = random.normal(keys[0], [self.pairs, x.size])
        u = z * self.scale
        # Both members of an antithetic pair share the same noise key.
        yp = jax.vmap(self.loss_fn)(jax.vmap(unravel)(x + u), keys[1:])
        ym = jax.vmap(self.loss_fn)(jax.vmap(unravel)(x - u), keys[1:])
        # Central-difference estimate of the gradient of the Gaussian-smoothed loss.
        g = (yp - ym) @ z / (len(z) * 2 * self.scale)
        grads = unravel(g)
        return grads

    def step(self, state, key):
        params, opt_state = state
        grads = self.grads(params, key)
        updates, opt_state = self.optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return (params, opt_state), None

    def params(self, state):
        params, opt_state = state
        return params


class EvosaxWrapper:
    """Adapts an evosax ask/tell strategy to the same init/step/params interface."""

    def __init__(self, loss_fn, strategy):
        self.loss_fn = loss_fn
        self.strategy = strategy

    def init(self, params, key):
        return self.strategy.initialize(key, init_mean=params)

    def step(self, state, key):
        keys = jax.random.split(key, 1 + self.strategy.popsize)
        x, state = self.strategy.ask(keys[0], state)
        # jnp.concatenate([keys[1:], keys[1:]])
        y = jax.vmap(self.loss_fn)(x, keys[1:])
        n_state = self.strategy.tell(x, y.astype(float), state)
        return n_state, None

    def params(self, state):
        return self.strategy.param_reshaper.reshape_single(state.mean)


def parse_args(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--scale", type=float, default=1e-1)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=1e-2)
    parser.add_argument("--steps", type=int, default=10**4)
    return parser.parse_args(argv)


def main(argv):
    args = parse_args(argv)

    # Noisy quadratic: the minimum is at a = 0.
    def loss_fn(params, key):
        return params["a"] @ params["a"] + random.normal(key) * 0.1

    params = {"a": jnp.ones(4)}
    key = random.PRNGKey(args.seed)
    # Run the existing evosax OpenES and the optax-based sketch side by side.
    for strategy in [
        EvosaxWrapper(
            loss_fn,
            evosax.OpenES(
                popsize=args.batch_size,
                pholder_params=params,
                opt_name="adam",
                lrate_init=args.lr,
                lrate_limit=0,
                sigma_init=args.scale,
                sigma_limit=0,
            ),
        ),
        OpenAIES(loss_fn, args.scale, args.batch_size, optax.adam(args.lr)),
    ]:
        keys = random.split(key, 1 + args.steps)
        state = strategy.init(params, keys[0])
        state, _ = lax.scan(strategy.step, state, keys[1:])
        print(strategy.params(state)["a"])


if __name__ == "__main__":
    main(sys.argv[1:])
```

Output:
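One way to sanity-check the antithetic estimator in `OpenAIES.grads` is to compare it with `jax.grad` on a function whose gradient is known. The sketch below re-implements the same formula for a flat vector and a deterministic loss (no per-sample noise key); `antithetic_es_grad` is a hypothetical name, and the sample count is chosen only to keep the Monte Carlo error small:

```python
import jax
import jax.numpy as jnp
from jax import random


def antithetic_es_grad(f, x, key, pairs, scale):
    """Mirrored-sampling ES gradient estimate of a deterministic f at flat x."""
    z = random.normal(key, (pairs, x.size))
    yp = jax.vmap(f)(x + scale * z)  # losses at x + scale * z_i
    ym = jax.vmap(f)(x - scale * z)  # losses at the mirrored points
    # Same normalization as OpenAIES.grads: divide by (pairs * 2 * scale).
    return (yp - ym) @ z / (pairs * 2 * scale)


def f(v):
    return v @ v  # exact gradient is 2 * v


x = jnp.ones(4)
estimate = antithetic_es_grad(f, x, random.PRNGKey(0), pairs=100_000, scale=0.1)
exact = jax.grad(f)(x)
print(estimate, exact)
```

For this quadratic, f(x + σz) - f(x - σz) = 4σ(x·z), so the estimate is (2/n)Σ(x·z_i)z_i, whose expectation is exactly 2x; the only error is sampling noise, which shrinks like 1/sqrt(pairs).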
optax is the most popular JAX library for optimizers. Feature request: let users pass an `optax.GradientTransformation` to strategy constructors, rather than a combination of `opt_name`, `lrate_init`, etc. This has a few advantages:

Thank you for creating this library. I look forward to hearing your thoughts.