diff --git a/dev_tools/requirements/envs/dev.env.txt b/dev_tools/requirements/envs/dev.env.txt index 9d664981..90abb769 100644 --- a/dev_tools/requirements/envs/dev.env.txt +++ b/dev_tools/requirements/envs/dev.env.txt @@ -56,9 +56,9 @@ iniconfig==2.0.0 # via pytest isort==5.13.2 # via pylint -jax==0.4.23 +jax==0.4.31 # via -r deps/resource_estimates_runtime.txt -jaxlib==0.4.23 +jaxlib==0.4.31 # via -r deps/resource_estimates_runtime.txt jsonschema==4.21.0 # via nbformat diff --git a/dev_tools/requirements/envs/pytest-extra.env.txt b/dev_tools/requirements/envs/pytest-extra.env.txt index c361de5a..8f424118 100644 --- a/dev_tools/requirements/envs/pytest-extra.env.txt +++ b/dev_tools/requirements/envs/pytest-extra.env.txt @@ -74,11 +74,11 @@ iniconfig==2.0.0 # via # -c envs/dev.env.txt # pytest -jax==0.4.23 +jax==0.4.31 # via # -c envs/dev.env.txt # -r deps/resource_estimates_runtime.txt -jaxlib==0.4.23 +jaxlib==0.4.31 # via # -c envs/dev.env.txt # -r deps/resource_estimates_runtime.txt diff --git a/src/openfermion/resource_estimates/pbc/thc/factorizations/thc_jax.py b/src/openfermion/resource_estimates/pbc/thc/factorizations/thc_jax.py index 6c86bbef..45f9ce13 100644 --- a/src/openfermion/resource_estimates/pbc/thc/factorizations/thc_jax.py +++ b/src/openfermion/resource_estimates/pbc/thc/factorizations/thc_jax.py @@ -35,11 +35,10 @@ from pyscf.pbc import scf from scipy.optimize import minimize -from jax.config import config +import jax -config.update("jax_enable_x64", True) +jax.config.update("jax_enable_x64", True) -import jax import jax.numpy as jnp import jax.typing as jnpt diff --git a/src/openfermion/resource_estimates/thc/factorize_thc.py b/src/openfermion/resource_estimates/thc/factorize_thc.py index 29838c7d..86c9c9ab 100644 --- a/src/openfermion/resource_estimates/thc/factorize_thc.py +++ b/src/openfermion/resource_estimates/thc/factorize_thc.py @@ -18,6 +18,7 @@ def thc_via_cp3( bfgs_maxiter=5000, random_start_thc=True, verify=False, + penalty_param=None, ): """ THC-CP3 performs an SVD decomposition of the eri matrix followed by a CP @@ -36,6 +37,7 @@ def thc_via_cp3( random_start_thc - Perform random start for CP3. If false perform HOSVD start. verify - check eri properties. Default is False + penalty_param - penalty parameter for L2 regularization. Default is None. returns: eri_thc - (N x N x N x N) reconstructed ERIs from THC factorization @@ -115,7 +117,9 @@ def thc_via_cp3( if perform_bfgs_opt: x = np.hstack((thc_leaf.ravel(), thc_central.ravel())) # lbfgs_start_time = time.time() - x = lbfgsb_opt_thc_l2reg(eri_full, nthc, initial_guess=x, maxiter=bfgs_maxiter) + x = lbfgsb_opt_thc_l2reg( + eri_full, nthc, initial_guess=x, maxiter=bfgs_maxiter, penalty_param=penalty_param + ) # lbfgs_calc_time = time.time() - lbfgs_start_time thc_leaf = x[: norb * nthc].reshape(nthc, norb) # leaf tensor nthc x norb thc_central = x[norb * nthc : norb * nthc + nthc * nthc].reshape( diff --git a/src/openfermion/resource_estimates/thc/utils/thc_factorization.py b/src/openfermion/resource_estimates/thc/utils/thc_factorization.py index 922b6317..16f7ae6a 100644 --- a/src/openfermion/resource_estimates/thc/utils/thc_factorization.py +++ b/src/openfermion/resource_estimates/thc/utils/thc_factorization.py @@ -1,4 +1,5 @@ # coverage:ignore +# pylint: disable=wrong-import-position import os from uuid import uuid4 import h5py @@ -6,9 +7,12 @@ import numpy.random import numpy.linalg from scipy.optimize import minimize + import jax + +jax.config.update("jax_enable_x64", True) + import jax.numpy as jnp -from jax.config import config from jax import jit, grad from .adagrad import adagrad from .thc_objectives import ( @@ -22,7 +26,6 @@ # set mkl thread count for numpy einsum/tensordot calls # leave one CPU un used so we can still access this computer os.environ["MKL_NUM_THREADS"] = "{}".format(os.cpu_count() - 1) -config.update("jax_enable_x64", True) class CallBackStore: diff --git a/src/openfermion/resource_estimates/thc/utils/thc_objectives.py b/src/openfermion/resource_estimates/thc/utils/thc_objectives.py index 3cbdd96a..366f25b2 100644 --- a/src/openfermion/resource_estimates/thc/utils/thc_objectives.py +++ b/src/openfermion/resource_estimates/thc/utils/thc_objectives.py @@ -1,9 +1,14 @@ # coverage:ignore +# pylint: disable=wrong-import-position import os from uuid import uuid4 import scipy.optimize + +import jax + +jax.config.update("jax_enable_x64", True) + import jax.numpy as jnp -from jax.config import config from jax import jit, grad import h5py import numpy @@ -15,7 +20,6 @@ # set mkl thread count for numpy einsum/tensordot calls # leave one CPU un used so we can still access this computer os.environ["MKL_NUM_THREADS"] = "{}".format(os.cpu_count() - 1) -config.update("jax_enable_x64", True) def thc_objective_jax(xcur, norb, nthc, eri):