
[BUG] Custom root fails for scalar inputs (jaxopt example) #222

Open
marvinfriede opened this issue May 27, 2024 · 0 comments

Labels
bug Something isn't working

marvinfriede commented May 27, 2024

Required prerequisites

What version of TorchOpt are you using?

0.7.3

System information

3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:36:39) [GCC 10.4.0] linux
0.7.3 2.1.0 2.1.0

Problem description

I tried to recreate the jaxopt root-finding example with implicit differentiation, using a very simple iterative solver. In JAX it just works; with TorchOpt, however, the gradient is zero for scalar inputs.

The problem stems from:

if maxiter is None:
    size = sum(cat_shapes(b))
    maxiter = 10 * size  # copied from SciPy

Here, size becomes zero for a scalar b, so maxiter is wrongly set to 0; the linear solve then never iterates and the implicit gradient comes out as zero. The same piece of code in JAX produces a size of 1.
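
To illustrate (my reading of the helper: for a single scalar tensor b, sum(cat_shapes(b)) reduces to sum(b.shape)), here is a minimal sketch of why the heuristic collapses to zero for 0-d tensors, and what counting elements instead would give:

import torch

b = torch.tensor(2.0)  # 0-d (scalar) tensor: b.shape == torch.Size([])
print(sum(b.shape))    # 0  -> maxiter = 10 * 0 = 0, so the solver never runs
print(b.numel())       # 1  -> element count, which would keep maxiter positive

For comparison, a scalar array in JAX has size 1 (jnp.size(jnp.array(2.0)) == 1), consistent with the observation above. A possible fix might be to use the element count (or max(size, 1)) instead of the summed shape, though I have not tested this against the TorchOpt internals.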

Reproducible example code

The Python snippet:

import torch
import torchopt


# https://jaxopt.github.io/stable/root_finding.html
def F(x, factor):
    return factor * x**3 - x - 2


@torchopt.diff.implicit.custom_root(
    F, argnums=(1,), solve=torchopt.linear_solve.solve_cg()
)
def custom_root_solver(init_x, factor):
    """Root solver using gradient descent."""
    maxiter = 100
    lr = 1e-1

    x = init_x
    for _ in range(maxiter):
        grad = F(x, factor)
        x = x - lr * grad

    return x


def wrapper(fac):
    return custom_root_solver(init_x, fac)


init_x = torch.tensor(1.0)
fac = torch.tensor(2.0, requires_grad=True)

root = wrapper(fac)
root_grad = torch.autograd.grad(root, fac)
print(root_grad)
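
With these scalar tensors the script prints a zero gradient (something like (tensor(0.),)) instead of the nonzero implicit gradient that the 1-D workaround in the Additional context section produces.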

Traceback

No response

Expected behavior

No response

Additional context

It works upon making the tensors 1D:

init_x = torch.tensor([1.0])
fac = torch.tensor([2.0], requires_grad=True)

In JAX, the same example works out of the box:

import jax
import jax.numpy as jnp
from jaxopt.implicit_diff import custom_root
from jaxopt import Bisection

jax.config.update("jax_platform_name", "cpu")


def F(x, factor):
    return factor * x**3 - x - 2


def bisection_root_solver(init_x, factor):
    bisec = Bisection(optimality_fun=F, lower=1, upper=2)
    return bisec.run(factor=factor).params


@custom_root(F)
def custom_root_solver(init_x, factor):
    """Root solver using gradient descent."""
    maxiter = 100
    lr = 1e-1

    x = init_x
    for _ in range(maxiter):
        grad = F(x, factor)
        x = x - lr * grad

    return x


x_init = jnp.array(3.0)
fac = jnp.array(2.0)

print(custom_root_solver(x_init, fac))
print(bisection_root_solver(x_init, fac))

print(jax.grad(custom_root_solver, argnums=1)(x_init, fac))
print(jax.grad(bisection_root_solver, argnums=1)(x_init, fac))
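
For a reference value independent of either library, the implicit function theorem gives dx*/dfactor = -(∂F/∂x)^(-1) · ∂F/∂factor at the root, which is the quantity custom_root is supposed to return. Below is a quick sketch with plain torch.autograd (variable names are mine, not TorchOpt API) that the JAX gradients above, and a fixed TorchOpt, should match:

import torch


def F(x, factor):
    return factor * x**3 - x - 2


# Find the root of F(., factor) with the same fixed-point iteration as above.
factor = torch.tensor(2.0)
x = torch.tensor(1.0)
for _ in range(100):
    x = x - 1e-1 * F(x, factor)

# Implicit function theorem: dx*/dfactor = -(dF/dx)^(-1) * dF/dfactor at the root.
x = x.detach().requires_grad_(True)
factor = factor.detach().requires_grad_(True)
dF_dx, dF_dfactor = torch.autograd.grad(F(x, factor), (x, factor))
print(-dF_dfactor / dF_dx)  # reference value for d(root)/d(factor)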