[Feature request] Integration with equinox #52

Open

AntonyJia159 opened this issue Apr 25, 2023 · 2 comments

AntonyJia159 commented Apr 25, 2023

Equinox is a minimalistic JAX neural network library. Currently, evosax's ParameterReshaper doesn't seem to support it, as the returned modules have a population (batch) dimension in their weights. Something like

rng = jax.random.PRNGKey(0)
network = eqx.nn.GRUCell(16, 20, key=rng)
param_reshaper = ParameterReshaper(network)

gets:

GRUCell(
  weight_ih=f32[1,60,16],
  weight_hh=f32[1,60,20],
  bias=f32[1,60],
  bias_n=f32[1,20],
  input_size=16,
  hidden_size=20,
  use_bias=True
)

For a population (here of size one), the correct form would instead be multiple separate cells, each shaped like this:

GRUCell(
  weight_ih=f32[60,16],
  weight_hh=f32[60,20],
  bias=f32[60],
  bias_n=f32[20],
  input_size=16,
  hidden_size=20,
  use_bias=True
)
It's possible that I simply missed something about how the reshaper works; if that's the case, please kindly let me know.
evosax is an awesome project, and more integration with the rest of the JAX ecosystem would make it even better : )

RobertTLange (Owner) commented

Thank you for raising this @AntonyJia159! I would love to support Equinox -- although I have to say that I have never worked with it and, as you correctly point out, evosax so far only supports neural network libraries that expose their weights as explicit pytrees. ParameterReshaper requires a pytree as input in order to extract the correct shapes for reshaping the flat parameter vectors used in ES. So far I couldn't figure out how to elegantly extract these from Equinox modules -- but I am sure that this is possible. Furthermore, we might need a smooth way to plug the proposed ES candidate weights back in for forward passes. Please let me know if this makes sense and if you have a proposal! Cheers, Rob
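
(A minimal sketch, added for illustration, of how the weight pytree could be extracted from an Equinox module and plugged back in, using Equinox's filtering API eqx.partition / eqx.combine; this is one possible route, not an approach confirmed in this thread.)

import jax
import equinox as eqx

rng = jax.random.PRNGKey(0)
network = eqx.nn.GRUCell(16, 20, key=rng)

# Split the module into its array leaves (the trainable weights) and the
# remaining static fields (input_size, hidden_size, use_bias, ...).
params, static = eqx.partition(network, eqx.is_array)

# `params` is now a plain pytree of JAX arrays, i.e. the kind of structure
# from which parameter shapes can be read off for flattening/reshaping.

# Given a candidate weight pytree with the same structure (e.g. one member
# of an ES population), the full module can be rebuilt for a forward pass:
candidate_network = eqx.combine(params, static)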

hctomkins commented Dec 20, 2023

FYI @RobertTLange @AntonyJia159 - Equinox works ok out of the box! You just need to vmap the evaluation of the networks too:

import jax
import jax.numpy as jnp
from evosax import CMA_ES, ParameterReshaper
import equinox as eqx


def fitness(net, input, target):
    output = net(input)
    return jnp.mean((output - target) ** 2)


if __name__ == "__main__":
    # Set up single equinox network
    rng = jax.random.PRNGKey(0)
    fixed_input = jax.random.uniform(key=rng, shape=(16,))
    fixed_output = 18.0

    network = eqx.nn.Linear(16, 1, key=rng)
    print("Example single random fitness ", fitness(network, fixed_input, fixed_output))

    # Set up for evosax
    reshaper = ParameterReshaper(network)
    fitness_many = jax.vmap(fitness, in_axes=(0, None, None)) # Adjust 'None's if working on batched rather than fixed data

    # Instantiate the search strategy
    strategy = CMA_ES(popsize=20, num_dims=reshaper.total_params, elite_ratio=0.5)
    es_params = strategy.default_params
    state = strategy.initialize(rng, es_params)

    # Run ask-eval-tell loop
    for t in range(10):
        rng, rng_gen, rng_eval = jax.random.split(rng, 3)
        candidate_params, state = strategy.ask(rng_gen, state, es_params)
        candidate_networks = reshaper.reshape(candidate_params)
        fitnesses = fitness_many(candidate_networks, fixed_input, fixed_output)
        print(jnp.min(fitnesses))
        state = strategy.tell(candidate_params, fitnesses, state, es_params)

    # Get best overall population member & its fitness
    print("best fitness:", state.best_fitness)
    print("best params:", state.best_member)

Gives:

Example single fitness  311.59064
ParameterReshaper: 17 parameters detected for optimization.
173.77896
34.950092
19.661118
3.1265647
0.10655613
0.3538381
0.049168102
0.32669815
0.00019323104
0.019950658
best fitness: 0.00019323104
best params: [ 2.3055873   2.9635496   5.8396025   4.113425    0.47761774  2.2564042
  2.6945195  -1.506904   -0.23346913 -1.9476291   3.9967122  -1.9457947
 -3.7017791  -1.44961     0.69735444  1.0968039   5.4032116 ]

Hopefully this is helpful to you guys / others!
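
(A small follow-up sketch for illustration: to evaluate the best member found above as a network, the best flat vector can be reshaped as a population of size one and passed through the same vmapped fitness function. This reuses only the reshaper.reshape and fitness_many calls already shown in the snippet above.)

# Illustration: rebuild a "population" of size one from the best flat vector
# and evaluate it with the vmapped fitness function from the example above.
best_as_pop = reshaper.reshape(state.best_member[None, :])
print("best member fitness:", fitness_many(best_as_pop, fixed_input, fixed_output))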
