[Feature request] Integration with equinox #52
Thank you for raising this @AntonyJia159! I would love to support Equinox, although I have to say that I have never worked with it and, as you correctly pointed out, evosax so far only supports neural network libraries that use weight pytrees more explicitly.
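Equinox modules are themselves pytrees: array fields become leaves, while configuration such as layer sizes is typically marked static and kept out of the leaves. The minimal JAX-only sketch below illustrates that idea with a hypothetical `TinyLinear` class registered by hand. It is not Equinox's actual implementation. It shows why flattening a module into one vector and rebuilding it (here with `jax.flatten_util.ravel_pytree`) can work without any Equinox-specific code:

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


@jax.tree_util.register_pytree_node_class
class TinyLinear:
    """Hypothetical stand-in for an Equinox-style module: arrays are leaves,
    integer sizes are static metadata stored in the treedef."""

    def __init__(self, weight, bias, in_features, out_features):
        self.weight = weight
        self.bias = bias
        self.in_features = in_features
        self.out_features = out_features

    def __call__(self, x):
        return self.weight @ x + self.bias

    def tree_flatten(self):
        children = (self.weight, self.bias)  # dynamic leaves
        aux_data = (self.in_features, self.out_features)  # static metadata
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, *aux_data)


key_w, key_b = jax.random.split(jax.random.PRNGKey(0))
net = TinyLinear(
    jax.random.normal(key_w, (1, 16)), jax.random.normal(key_b, (1,)), 16, 1
)

leaves = jax.tree_util.tree_leaves(net)  # only weight and bias
flat, unravel = ravel_pytree(net)  # one vector of 16 + 1 = 17 numbers
rebuilt = unravel(flat)  # a TinyLinear again, sizes restored from the treedef

print(len(leaves), flat.shape, rebuilt.in_features, rebuilt.out_features)
```

Because the sizes live in the tree structure rather than in the leaves, only the 17 trainable numbers are exposed to a search strategy.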
FYI @RobertTLange @AntonyJia159: Equinox works OK out of the box! You just need to vmap the evaluation of the networks too:

```python
import jax
import jax.numpy as jnp
from evosax import CMA_ES, ParameterReshaper
import equinox as eqx


def fitness(net, input, target):
    output = net(input)
    return jnp.mean((output - target) ** 2)


if __name__ == "__main__":
    # Set up single equinox network
    rng = jax.random.PRNGKey(0)
    fixed_input = jax.random.uniform(key=rng, shape=(16,))
    fixed_output = 18.0
    network = eqx.nn.Linear(16, 1, key=rng)
    print("Example single random fitness ", fitness(network, fixed_input, fixed_output))

    # Set up for evosax
    reshaper = ParameterReshaper(network)
    fitness_many = jax.vmap(fitness, in_axes=(0, None, None))  # Adjust 'None's if working on batched rather than fixed data

    # Instantiate the search strategy
    strategy = CMA_ES(popsize=20, num_dims=reshaper.total_params, elite_ratio=0.5)
    es_params = strategy.default_params
    state = strategy.initialize(rng, es_params)

    # Run ask-eval-tell loop
    for t in range(10):
        rng, rng_gen, rng_eval = jax.random.split(rng, 3)
        candidate_params, state = strategy.ask(rng_gen, state, es_params)
        # A single eqx.nn.Linear whose arrays carry a leading population axis of size 20
        candidate_networks = reshaper.reshape(candidate_params)
        fitnesses = fitness_many(candidate_networks, fixed_input, fixed_output)
        print(jnp.min(fitnesses))
        state = strategy.tell(candidate_params, fitnesses, state, es_params)

    # Get best overall population member & its fitness
    print("best fitness:", state.best_fitness)
    print("best params:", state.best_member)
```

Gives:
Hopefully this is helpful to you guys / others!
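The key step in the script above is that the reshaped population is one pytree whose leaves carry a leading population axis, so `jax.vmap(fitness, in_axes=(0, None, None))` maps over members while reusing the same input and target. The minimal sketch below shows that pattern with plain JAX. It rests on a few assumptions: `jax.flatten_util.ravel_pytree` plus `jax.vmap` stands in for evosax's reshaping, random vectors stand in for `strategy.ask`, and a hypothetical dict of arrays stands in for the Equinox module:

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

POPSIZE = 20

key_x, key_pop = jax.random.split(jax.random.PRNGKey(0))

# Hypothetical single-network parameters (a dict pytree standing in for a module).
single = {"weight": jnp.zeros((1, 16)), "bias": jnp.zeros((1,))}
flat, unravel = ravel_pytree(single)
num_dims = flat.shape[0]  # 16 weights + 1 bias


def apply(params, x):
    return params["weight"] @ x + params["bias"]


def fitness(params, x, target):
    return jnp.mean((apply(params, x) - target) ** 2)


# Stand-in for strategy.ask: a (POPSIZE, num_dims) matrix of flat candidates.
candidates = jax.random.normal(key_pop, (POPSIZE, num_dims))

# Stand-in for reshaper.reshape: unravel each row, producing one pytree
# whose leaves have a leading population axis.
population = jax.vmap(unravel)(candidates)

# Map over population members (axis 0); share the same input and target.
fitness_many = jax.vmap(fitness, in_axes=(0, None, None))
x = jax.random.uniform(key_x, (16,))
fitnesses = fitness_many(population, x, 18.0)

print(population["weight"].shape, fitnesses.shape)
```

Here `population["weight"]` has shape `(20, 1, 16)` and `fitnesses` has one entry per member, matching what `strategy.tell` expects in the loop above.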
Equinox is a minimalistic JAX neural network library. Currently, evosax's "parameter_reshaper" doesn't seem to support it, as the returned modules have a population (batch) dimension in their weights:
Something like

```python
import jax
import equinox as eqx
from evosax import ParameterReshaper

rng = jax.random.PRNGKey(0)
network = eqx.nn.GRUCell(16, 20, key=rng)
param_reshaper = ParameterReshaper(network)
```
gets:

```
GRUCell(
  weight_ih=f32[1,60,16],
  weight_hh=f32[1,60,20],
  bias=f32[1,60],
  bias_n=f32[1,20],
  input_size=16,
  hidden_size=20,
  use_bias=True
)
```
For a population of size one, the correct form should instead be one cell per population member, each like this:
```
GRUCell(
  weight_ih=f32[60,16],
  weight_hh=f32[60,20],
  bias=f32[60],
  bias_n=f32[20],
  input_size=16,
  hidden_size=20,
  use_bias=True
)
```
It's possible that there is just something I haven't figured out about the reshaper. If that's the case, please let me know.
Evosax is an awesome project, and deeper integration with the wider ecosystem would make it even better :)