Skip to content

Commit

Permalink
feat: reconstruct input data after rotation
Browse files Browse the repository at this point in the history
  • Loading branch information
nicrie committed Feb 24, 2022
1 parent 7ed306a commit 0c9479e
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 6 deletions.
64 changes: 61 additions & 3 deletions xeofs/models/_base_rotator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import numpy as np
import scipy as sc
from typing import Tuple
from typing import Optional, Union, List, Tuple

from .eof import EOF
from ..utils.rotation import promax
from ..utils.tools import get_mode_selector


class _BaseRotator:
Expand Down Expand Up @@ -147,8 +148,6 @@ def pcs(self, scaling : int = 0) -> np.ndarray:
pcs = self._pcs * np.sqrt(self._explained_variance * self._model.n_samples)
return pcs

return self._pcs

def eofs_as_correlation(self) -> Tuple[np.ndarray, np.ndarray]:
'''Correlation coefficients between rotated PCs and data matrix.
Expand All @@ -169,3 +168,62 @@ def eofs_as_correlation(self) -> Tuple[np.ndarray, np.ndarray]:
dist = sc.stats.beta(a, a, loc=-1, scale=2)
pvals = 2 * dist.cdf(-abs(corr))
return corr, pvals

def reconstruct_X(
self,
mode : Optional[Union[int, List[int], slice]] = None
) -> np.ndarray:
'''Reconstruct original data field ``X`` using the rotated PCs and EOFs.
If weights were applied, ``X`` will be automatically rescaled.
Parameters
----------
mode : Optional[Union[int, List[int], slice]]
Mode(s) based on which ``X`` will be reconstructed. If ``mode`` is
an int, a single mode is used. If a list of integers is provided,
use all specified modes for reconstruction. Alternatively, you may
want to select a slice to reconstruct. The first mode is denoted
by 1 (and not by 0). If None then ``X`` is recontructed using all
available modes (the default is None).
Examples
--------
Perform an analysis using some data ``X``:
>>> model = EOF(X, norm=True)
>>> model.solve()
Reconstruct ``X`` using all modes:
>>> model.reconstruct_X()
Reconstruct ``X`` using the first mode only:
>>> model.reconstruct_X(1)
Reconstruct ``X`` using mode 1, 3 and 4:
>>> model.reconstruct_X([1, 3, 4])
Reconstruct ``X`` using all modes up to mode 10 (including):
>>> model.reconstruct_X(slice(10))
Reconstruct ``X`` using every second mode between 4 and 8 (both
including):
>>> model.reconstruct_X(slice(4, 8, 2))
'''
eofs = self._eofs
pcs = self._pcs * np.sqrt(self._explained_variance * self._model.n_samples)
# Select modes to reconstruct X
mode = get_mode_selector(mode)
eofs = eofs[:, mode]
pcs = pcs[:, mode]
Xrec = pcs @ eofs.T
# Unweight and add mean
return (Xrec / self._model._weights) + self._model._X_mean
10 changes: 9 additions & 1 deletion xeofs/models/rotator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from typing import Tuple
from typing import Optional, Union, List, Tuple

from .eof import EOF
from ._base_rotator import _BaseRotator
Expand Down Expand Up @@ -59,3 +59,11 @@ def eofs_as_correlation(self) -> Tuple[np.ndarray, np.ndarray]:
corr = self._model._tf.back_transform_eofs(corr)
pvals = self._model._tf.back_transform_eofs(pvals)
return corr, pvals

def reconstruct_X(
self,
mode : Optional[Union[int, List[int], slice]] = None
) -> np.ndarray:
Xrec = super().reconstruct_X(mode=mode)
Xrec = self._model._tf.back_transform(Xrec)
return Xrec
11 changes: 10 additions & 1 deletion xeofs/pandas/rotator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pandas as pd
from typing import Tuple
from typing import Optional, Union, List, Tuple

from .eof import EOF
from ..models._base_rotator import _BaseRotator
Expand Down Expand Up @@ -68,3 +68,12 @@ def eofs_as_correlation(self) -> Tuple[pd.DataFrame, pd.DataFrame]:
corr.columns = self._model._idx_mode[:self._n_rot]
pvals.columns = self._model._idx_mode[:self._n_rot]
return corr, pvals

def reconstruct_X(
self,
mode : Optional[Union[int, List[int], slice]] = None
) -> pd.DataFrame:
Xrec = super().reconstruct_X(mode=mode)
Xrec = self._model._tf.back_transform(Xrec)
Xrec.index = self._model._tf.index_samples
return Xrec
13 changes: 12 additions & 1 deletion xeofs/xarray/rotator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import xarray as xr
from typing import Tuple
from typing import Tuple, Optional, Union, List

from .eof import EOF
from ..models._base_rotator import _BaseRotator
Expand Down Expand Up @@ -76,3 +76,14 @@ def eofs_as_correlation(self) -> Tuple[xr.DataArray, xr.DataArray]:
corr.name = 'correlation_coeffient'
pvals.name = 'p_value'
return corr, pvals

def reconstruct_X(
self,
mode : Optional[Union[int, List[int], slice]] = None
) -> xr.DataArray:
Xrec = super().reconstruct_X(mode=mode)
Xrec = self._model._tf.back_transform(Xrec)
coords_samples = {d: self._model._tf.coords[d] for d in self._model._tf.dims}
Xrec = Xrec.assign_coords(coords_samples)
Xrec.name = 'X_reconstructed'
return Xrec

0 comments on commit 0c9479e

Please sign in to comment.