Skip to content

Commit

Permalink
Allow the bisection root finder to expand the interval
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar authored and patrick-kidger committed Aug 22, 2024
1 parent 1c47169 commit 1de4264
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 7 deletions.
89 changes: 82 additions & 7 deletions optimistix/_solver/bisection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools as ft
from collections.abc import Callable
from typing import Any, ClassVar, Literal, Union

Expand All @@ -18,6 +19,63 @@ class _BisectionState(eqx.Module, strict=True):
error: Float[Array, ""]


class _ExpansionCarry(eqx.Module, strict=True):
lower: Scalar
upper: Scalar


def _interval_contains_root(
carry: _ExpansionCarry,
/,
*,
need_positive: Bool[Array, ""],
expand_upper: Bool[Array, ""],
fn: Fn[Scalar, Scalar, Aux],
args: Any,
) -> Bool[Array, ""]:
new_boundary = jnp.where(expand_upper, carry.upper, carry.lower)
carry_val, _ = fn(new_boundary, args)
return need_positive ^ (carry_val > 0.0)


def _expand_interval(
carry: _ExpansionCarry,
/,
*,
expand_upper: Bool[Array, ""],
) -> _ExpansionCarry:
new_domain = 2.0 * (carry.upper - carry.lower)
new_lower = jnp.where(expand_upper, carry.upper, carry.lower - new_domain)
new_upper = jnp.where(expand_upper, carry.upper + new_domain, carry.lower)
return _ExpansionCarry(new_lower, new_upper)


def _expand_interval_repeatedly(
lower: Scalar,
upper: Scalar,
*,
upper_val: Scalar,
lower_val: Scalar,
need_positive: Bool[Array, ""],
fn: Fn[Scalar, Scalar, Any],
args: PyTree,
) -> tuple[Scalar, Scalar]:
initial_interval = _ExpansionCarry(lower, upper)
expand_upper = need_positive ^ (upper_val < lower_val)
cond_fun = ft.partial(
_interval_contains_root,
need_positive=need_positive,
expand_upper=expand_upper,
fn=fn,
args=args,
)
body_fun = ft.partial(_expand_interval, expand_upper=expand_upper)
final_interval = jax.lax.while_loop(cond_fun, body_fun, initial_interval)
lower = final_interval.lower
upper = final_interval.upper
return lower, upper


class Bisection(AbstractRootFinder[Scalar, Scalar, Aux, _BisectionState], strict=True):
"""The bisection method of root finding. This may only be used with functions
`R->R`, i.e. functions with scalar input and scalar output.
Expand All @@ -34,11 +92,16 @@ class Bisection(AbstractRootFinder[Scalar, Scalar, Aux, _BisectionState], strict
sign of the evaluated function at the midpoint of the interval, and then keeping
whichever half contains the root. This is then repeated. The iteration stops once
the interval is sufficiently small.
If expand_if_necessary and detect are true, the initial interval will be expanded
if it doesn't contain the the root. This expansion assumes that the function is
monotonic.
"""

rtol: float
atol: float
flip: Union[bool, Literal["detect"]] = "detect"
expand_if_necessary: bool = False
# All norms are the same for scalars.
norm: ClassVar[Callable[[PyTree], Scalar]] = jnp.abs

Expand Down Expand Up @@ -71,14 +134,26 @@ def init(
elif self.flip == "detect":
lower_val, _ = fn(lower, args)
upper_val, _ = fn(upper, args)
lower_neg = lower_val < 0
upper_neg = upper_val < 0
flip = lower_val > upper_val
flip = eqx.error_if(
flip,
lower_neg == upper_neg,
msg="The root is not contained in [lower, upper]",
)
if self.expand_if_necessary:
lower, upper = _expand_interval_repeatedly(
lower,
upper,
upper_val=upper_val,
lower_val=lower_val,
need_positive=lower_val < 0.0,
fn=fn,
args=args,
)
else:
lower_neg = lower_val < 0
upper_neg = upper_val < 0
root_not_contained = lower_neg == upper_neg
flip = eqx.error_if(
flip,
root_not_contained,
msg="The root is not contained in [lower, upper]",
)
else:
raise ValueError("`flip` may only be True, False, or 'detect'.")
return _BisectionState(
Expand Down
18 changes: 18 additions & 0 deletions tests/test_root_find.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,24 @@ def test_bisection_flip():
assert jnp.allclose(0, sol.value, atol=1e-3)


@pytest.mark.parametrize(
"fn", [lambda x, _: x, lambda x, _: -x], ids=["positive", "negative"]
)
@pytest.mark.parametrize(
("lower", "upper"),
[(lower, upper) for lower in (-3, 3) for upper in (-4, -2, 2, 4)],
)
def test_bisection_expansion(fn, lower, upper):
options = {"lower": lower, "upper": upper}
sol = optx.root_find(
fn,
optx.Bisection(rtol=0.0, atol=1e-4, expand_if_necessary=True),
100.0,
options=options,
)
assert jnp.allclose(0, sol.value, atol=1e-3)


@pytest.mark.parametrize(
"solver", [optx.Newton(rtol=1e-5, atol=1e-5), optx.Chord(rtol=1e-5, atol=1e-5)]
)
Expand Down

0 comments on commit 1de4264

Please sign in to comment.