Skip to content

Commit

Permalink
fix: complex decomposition not used
Browse files Browse the repository at this point in the history
  • Loading branch information
nicrie committed Jun 30, 2023
1 parent bc6d1c6 commit 2086546
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 27 deletions.
86 changes: 86 additions & 0 deletions tests/models/test_decomposer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import numpy as np
import xarray as xr
import pytest
from dask.array import Array as DaskArray # type: ignore
from sklearn.utils.extmath import randomized_svd as svd
from scipy.sparse.linalg import svds as complex_svd # type: ignore
from dask.array.linalg import svd_compressed as dask_svd
from xeofs.models.decomposer import Decomposer, CrossDecomposer


@pytest.fixture
def decomposer():
return Decomposer(n_modes=2, n_iter=3, random_state=42, verbose=False)

@pytest.fixture
def cross_decomposer():
return CrossDecomposer(n_modes=2, n_iter=3, random_state=42, verbose=False)

@pytest.fixture
def test_DataArray(test_DataArray):
return test_DataArray.stack(sample=('time',), feature=('x', 'y')).dropna('feature')

@pytest.fixture
def test_DaskDataArray(test_DataArray):
return test_DataArray.chunk({'sample': 1})

@pytest.fixture
def test_complex_DataArray(test_DataArray):
return test_DataArray * (1 + 1j)

@pytest.fixture
def test_complex_DaskDataArray(test_complex_DataArray):
return test_complex_DataArray.chunk({'sample': 1})


def test_decomposer_init(decomposer):
assert decomposer.params['n_modes'] == 2
assert decomposer.params['n_iter'] == 3
assert decomposer.params['random_state'] == 42
assert decomposer.params['verbose'] == False

def test_cross_decomposer_init(cross_decomposer):
assert cross_decomposer.params['n_modes'] == 2
assert cross_decomposer.params['n_iter'] == 3
assert cross_decomposer.params['random_state'] == 42
assert cross_decomposer.params['verbose'] == False

def test_decomposer_fit(decomposer, test_DataArray):
decomposer.fit(test_DataArray)
assert 'scores_' in decomposer.__dict__
assert 'singular_values_' in decomposer.__dict__
assert 'components_' in decomposer.__dict__

def test_decomposer_fit_dask(decomposer, test_DaskDataArray):
decomposer.fit(test_DaskDataArray)
assert 'scores_' in decomposer.__dict__
assert 'singular_values_' in decomposer.__dict__
assert 'components_' in decomposer.__dict__

def test_decomposer_fit_complex(decomposer, test_complex_DataArray):
decomposer.fit(test_complex_DataArray)
assert 'scores_' in decomposer.__dict__
assert 'singular_values_' in decomposer.__dict__
assert 'components_' in decomposer.__dict__

def test_cross_decomposer_fit(cross_decomposer, test_DataArray):
cross_decomposer.fit(test_DataArray, test_DataArray)
assert 'singular_vectors1_' in cross_decomposer.__dict__
assert 'singular_values_' in cross_decomposer.__dict__
assert 'singular_vectors2_' in cross_decomposer.__dict__

def test_cross_decomposer_fit_complex(cross_decomposer, test_complex_DataArray):
cross_decomposer.fit(test_complex_DataArray, test_complex_DataArray)
assert 'singular_vectors1_' in cross_decomposer.__dict__
assert 'singular_values_' in cross_decomposer.__dict__
assert 'singular_vectors2_' in cross_decomposer.__dict__

def test_cross_decomposer_fit_dask(cross_decomposer, test_DaskDataArray):
cross_decomposer.fit(test_DaskDataArray, test_DaskDataArray)
assert 'singular_vectors1_' in cross_decomposer.__dict__
assert 'singular_values_' in cross_decomposer.__dict__
assert 'singular_vectors2_' in cross_decomposer.__dict__

def test_cross_decomposer_fit_same_samples(cross_decomposer, test_DataArray):
with pytest.raises(ValueError):
cross_decomposer.fit(test_DataArray, test_DataArray.isel(sample=slice(1,3)))
48 changes: 21 additions & 27 deletions xeofs/models/decomposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@ class Decomposer():
Parameters
----------
n_components : int
n_modes : int
Number of components to be computed.
allow_complex : bool
If True, the data is allowed to be complex. If False, the data is assumed to be real.
n_iter : int
Number of iterations for the SVD algorithm.
random_state : int
Expand All @@ -25,11 +23,9 @@ class Decomposer():
If True, print information about the SVD algorithm.
'''
def __init__(self, n_components=100, allow_complex=False, n_iter=5, random_state=None, verbose=False):
def __init__(self, n_modes=100, n_iter=5, random_state=None, verbose=False):
self.params = {
'n_components': n_components,
'allow_complex': allow_complex,
'use_dask': False,
'n_modes': n_modes,
'n_iter': n_iter,
'random_state': random_state,
'verbose': verbose,
Expand All @@ -38,12 +34,12 @@ def __init__(self, n_components=100, allow_complex=False, n_iter=5, random_state
def fit(self, X):
svd_kwargs = {}

if isinstance(X.data, DaskArray):
self.params['use_dask'] = True
is_dask = True if isinstance(X.data, DaskArray) else False
is_complex = True if np.iscomplexobj(X.data) else False

if (not self.params['allow_complex']) and (not self.params['use_dask']):
if (not is_complex) and (not is_dask):
svd_kwargs.update({
'n_components': self.params['n_components'],
'n_components': self.params['n_modes'],
'random_state': self.params['random_state']
})

Expand All @@ -55,11 +51,11 @@ def fit(self, X):
output_core_dims=[['sample', 'mode'], ['mode'], ['mode', 'feature']],
)

elif (self.params['allow_complex']) and (not self.params['use_dask']):
elif is_complex and (not is_dask):
# Scipy sparse version
svd_kwargs.update({
'solver': 'lobpcg',
'k': self.params['n_components'],
'k': self.params['n_modes'],
})
U, s, VT = xr.apply_ufunc(
complex_svd,
Expand All @@ -73,9 +69,9 @@ def fit(self, X):
s = s[idx_sort]
VT = VT[idx_sort, :]

elif (not self.params['allow_complex']) and (self.params['use_dask']):
elif (not is_complex) and is_dask:
svd_kwargs.update({
'k': self.params['n_components']
'k': self.params['n_modes']
})
U, s, VT = xr.apply_ufunc(
dask_svd,
Expand Down Expand Up @@ -125,10 +121,8 @@ class CrossDecomposer(Decomposer):
Parameters
----------
n_components : int
n_modes : int
Number of components to be computed.
allow_complex : bool
If True, the data is allowed to be complex. If False, the data is assumed to be real.
n_iter : int
Number of iterations for the SVD algorithm.
random_state : int
Expand Down Expand Up @@ -157,13 +151,13 @@ def fit(self, X1, X2):
# Compute squared total variance
self.squared_total_variance_ = (cov_matrix**2).sum().compute()

if isinstance(cov_matrix.data, DaskArray):
self.params['use_dask'] = True

is_dask = True if isinstance(cov_matrix.data, DaskArray) else False
is_complex = True if np.iscomplexobj(cov_matrix.data) else False
svd_kwargs = {}
if (not self.params['allow_complex']) and (not self.params['use_dask']):
if (not is_complex) and (not is_dask):
svd_kwargs.update({
'n_components': self.params['n_components'],
'n_components': self.params['n_modes'],
'random_state': self.params['random_state']
})

Expand All @@ -175,11 +169,11 @@ def fit(self, X1, X2):
output_core_dims=[['feature1', 'mode'], ['mode'], ['mode', 'feature2']],
)

elif (self.params['allow_complex']) and (not self.params['use_dask']):
elif (is_complex) and (not is_dask):
# Scipy sparse version
svd_kwargs.update({
'solver': 'lobpcg',
'k': self.params['n_components'],
'k': self.params['n_modes'],
})
U, s, VT = xr.apply_ufunc(
complex_svd,
Expand All @@ -193,9 +187,9 @@ def fit(self, X1, X2):
s = s[idx_sort]
VT = VT[idx_sort, :]

elif (not self.params['allow_complex']) and (self.params['use_dask']):
elif (not is_complex) and (is_dask):
svd_kwargs.update({
'k': self.params['n_components']
'k': self.params['n_modes']
})
U, s, VT = xr.apply_ufunc(
dask_svd,
Expand Down

0 comments on commit 2086546

Please sign in to comment.