Skip to content

Commit

Permalink
feat: complex MCA amplitude and phase
Browse files Browse the repository at this point in the history
  • Loading branch information
nicrie committed Jul 6, 2023
1 parent 083a8e0 commit 55ce3b1
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 2 deletions.
85 changes: 84 additions & 1 deletion tests/models/test_complex_mca.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,90 @@ def test_complex_mca_fit(mca_model, mock_data_array, dim):
assert mca_model._norm2 is not None


@pytest.mark.parametrize('dim', [
(('time',)),
(('lat', 'lon')),
(('lon', 'lat')),
])
def test_complex_mca_components(mca_model, mock_data_array, dim):
mca_model.fit(mock_data_array, mock_data_array, dim)
components = mca_model.components()
assert isinstance(components, tuple), 'components is not a tuple'
assert len(components) == 2, 'components list does not have 2 elements'
assert isinstance(components[0], xr.DataArray), 'components[0] is not a DataArray'
assert isinstance(components[1], xr.DataArray), 'components[1] is not a DataArray'


@pytest.mark.parametrize('dim', [
(('time',)),
(('lat', 'lon')),
(('lon', 'lat')),
])
def test_complex_mca_scores(mca_model, mock_data_array, dim):
mca_model.fit(mock_data_array, mock_data_array, dim)
scores = mca_model.scores()
assert isinstance(scores, tuple), 'scores is not a tuple'
assert len(scores) == 2, 'scores list does not have 2 elements'
assert isinstance(scores[0], xr.DataArray), 'scores[0] is not a DataArray'
assert isinstance(scores[1], xr.DataArray), 'scores[1] is not a DataArray'


@pytest.mark.parametrize('dim', [
(('time',)),
(('lat', 'lon')),
(('lon', 'lat')),
])
def test_complex_mca_components_amplitude(mca_model, mock_data_array, dim):
mca_model.fit(mock_data_array, mock_data_array, dim)
components = mca_model.components_amplitude()
assert isinstance(components, tuple), 'components is not a tuple'
assert len(components) == 2, 'components list does not have 2 elements'
assert isinstance(components[0], xr.DataArray), 'components[0] is not a DataArray'
assert isinstance(components[1], xr.DataArray), 'components[1] is not a DataArray'


@pytest.mark.parametrize('dim', [
(('time',)),
(('lat', 'lon')),
(('lon', 'lat')),
])
def test_complex_mca_components_phase(mca_model, mock_data_array, dim):
mca_model.fit(mock_data_array, mock_data_array, dim)
components = mca_model.components_phase()
assert isinstance(components, tuple), 'components is not a tuple'
assert len(components) == 2, 'components list does not have 2 elements'
assert isinstance(components[0], xr.DataArray), 'components[0] is not a DataArray'
assert isinstance(components[1], xr.DataArray), 'components[1] is not a DataArray'


@pytest.mark.parametrize('dim', [
(('time',)),
(('lat', 'lon')),
(('lon', 'lat')),
])
def test_complex_mca_scores_amplitude(mca_model, mock_data_array, dim):
mca_model.fit(mock_data_array, mock_data_array, dim)
scores = mca_model.scores_amplitude()
assert isinstance(scores, tuple), 'scores is not a tuple'
assert len(scores) == 2, 'scores list does not have 2 elements'
assert isinstance(scores[0], xr.DataArray), 'scores[0] is not a DataArray'
assert isinstance(scores[1], xr.DataArray), 'scores[1] is not a DataArray'


@pytest.mark.parametrize('dim', [
(('time',)),
(('lat', 'lon')),
(('lon', 'lat')),
])
def test_complex_mca_scores_phase(mca_model, mock_data_array, dim):
mca_model.fit(mock_data_array, mock_data_array, dim)
scores = mca_model.scores_phase()
assert isinstance(scores, tuple), 'scores is not a tuple'
assert len(scores) == 2, 'scores list does not have 2 elements'
assert isinstance(scores[0], xr.DataArray), 'scores[0] is not a DataArray'
assert isinstance(scores[1], xr.DataArray), 'scores[1] is not a DataArray'


@pytest.mark.parametrize('dim', [
(('time',)),
(('lat', 'lon')),
Expand Down Expand Up @@ -60,7 +144,6 @@ def test_complex_mca_transform_not_implemented(mca_model, mock_data_array, dim):
mca_model.transform(mock_data_array, mock_data_array)



def test_complex_mca_homogeneous_patterns_not_implemented():
mca = ComplexMCA()
with pytest.raises(NotImplementedError):
Expand Down
82 changes: 81 additions & 1 deletion xeofs/models/mca.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ._base_cross_model import _BaseCrossModel
from .decomposer import CrossDecomposer
from ..utils.data_types import XarrayData, DataArrayList
from ..utils.data_types import XarrayData, DataArrayList, DataArray, Dataset
from ..utils.statistics import pearson_correlation
from ..utils.xarray_utils import hilbert_transform

Expand Down Expand Up @@ -489,6 +489,86 @@ def fit(self, data1: XarrayData | DataArrayList, data2: XarrayData | DataArrayLi
# Assign analysis relevant meta data
self._assign_meta_data()

def components_amplitude(self) -> DataArray | Dataset | DataArrayList:
'''Compute the amplitude of the components.
Returns
-------
xr.DataArray
Amplitude of the components.
'''
comps1 = abs(self._singular_vectors1)
comps2 = abs(self._singular_vectors2)

comps1.name = 'singular_vector_amplitudes'
comps2.name = 'singular_vector_amplitudes'

comps1 = self.stacker1.inverse_transform_components(comps1) # type: ignore
comps2 = self.stacker2.inverse_transform_components(comps2) # type: ignore

return comps1, comps2 # type: ignore

def components_phase(self) -> DataArray | Dataset | DataArrayList:
'''Compute the phase of the components.
Returns
-------
xr.DataArray
Phase of the components.
'''
comps1 = xr.apply_ufunc(np.angle, self._singular_vectors1, keep_attrs=True)
comps2 = xr.apply_ufunc(np.angle, self._singular_vectors2, keep_attrs=True)

comps1.name = 'singular_vector_phases'
comps2.name = 'singular_vector_phases'

comps1 = self.stacker1.inverse_transform_components(comps1) # type: ignore
comps2 = self.stacker2.inverse_transform_components(comps2) # type: ignore

return comps1, comps2 # type: ignore

def scores_amplitude(self) -> DataArray | Dataset | DataArrayList:
'''Compute the amplitude of the scores.
Returns
-------
xr.DataArray
Amplitude of the scores.
'''
scores1 = abs(self._scores1)
scores2 = abs(self._scores2)

scores1.name = 'score_amplitudes'
scores2.name = 'score_amplitudes'

scores1 = self.stacker1.inverse_transform_scores(scores1) # type: ignore
scores2 = self.stacker2.inverse_transform_scores(scores2) # type: ignore

return scores1, scores2 # type: ignore

def scores_phase(self) -> DataArray | Dataset | DataArrayList:
'''Compute the phase of the scores.
Returns
-------
xr.DataArray
Phase of the scores.
'''
scores1 = xr.apply_ufunc(np.angle, self._scores1, keep_attrs=True)
scores2 = xr.apply_ufunc(np.angle, self._scores2, keep_attrs=True)

scores1.name = 'score_phases'
scores2.name = 'score_phases'

scores1 = self.stacker1.inverse_transform_scores(scores1) # type: ignore
scores2 = self.stacker2.inverse_transform_scores(scores2) # type: ignore

return scores1, scores2 # type: ignore


def transform(self, data1: XarrayData | DataArrayList, data2: XarrayData | DataArrayList):
raise NotImplementedError('Complex MCA does not support transform method.')
Expand Down

0 comments on commit 55ce3b1

Please sign in to comment.