Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add importance_sample method to NestedSamples and MCMCSamples #122

Merged
merged 12 commits into from
Aug 20, 2020
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ anesthetic: nested sampling visualisation
=========================================
:anesthetic: nested sampling visualisation
:Author: Will Handley and Lukas Hergt
:Version: 2.0.0-beta.2
:Version: 2.0.0-beta.3
:Homepage: https:/williamjameshandley/anesthetic
:Documentation: http://anesthetic.readthedocs.io/

Expand Down
70 changes: 68 additions & 2 deletions anesthetic/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,40 @@ def plot_2d(self, axes, *args, **kwargs):

return fig, axes

def importance_sample(self, logL_new, action='add'):
"""Perform importance re-weighting on the log-likelihood.

Parameters
----------
logL_new: np.array
New log-likelihood values. Should have the same shape as `logL`.

action: str
Can be any of {'add', 'replace', 'mask'}.
* add: Add the new `logL_new` to the current `logL`.
* replace: Replace the current `logL` with the new `logL_new`.
* mask: treat `logL_new` as a boolean mask and only keep the
corresponding (True) samples.
default: 'add'

Returns
-------
samples: MCMCSamples
Importance re-weighted samples.
"""
samples = self.copy()
if action == 'add':
samples.logL += logL_new
elif action == 'replace':
samples.logL = logL_new
elif action == 'mask':
samples = samples[logL_new]
else:
raise NotImplementedError("`action` needs to be one of "
"{'add', 'replace', 'mask'}, but '%s' "
"was requested." % action)
return samples

def _limits(self, paramname):
limits = self.limits.get(paramname, (None, None))
if limits[0] == limits[1]:
Expand Down Expand Up @@ -611,6 +645,36 @@ def dlogX(self, nsamples=None):
else:
return WeightedDataFrame(dlogX, self.index, weights=self.weights)

def importance_sample(self, logL_new, action='add'):
"""Perform importance re-weighting on the log-likelihood.

Parameters
----------
logL_new: np.array
New log-likelihood values. Should have the same shape as `logL`.

action: str
Can be any of {'add', 'replace', 'mask'}.
* add: Add the new `logL_new` to the current `logL`.
* replace: Replace the current `logL` with the new `logL_new`.
* mask: treat `logL_new` as a boolean mask and only keep the
corresponding (True) samples.
default: 'add'

Returns
-------
samples: NestedSamples
Importance re-weighted samples.
"""
samples = merge_nested_samples((self, ))
samples = super(NestedSamples, samples).importance_sample(
logL_new=logL_new, action=action
)
samples = merge_nested_samples(
(samples[samples.logL > samples.logL_birth], )
)
williamjameshandley marked this conversation as resolved.
Show resolved Hide resolved
return samples

def _compute_nlive(self, logL_birth):
if is_int(logL_birth):
nlive = logL_birth
Expand All @@ -637,12 +701,14 @@ def _constructor(self):


def merge_nested_samples(runs):
"""Merge two or more nested sampling runs.
"""Merge one or more nested sampling runs.

Parameters
----------
runs: list(NestedSamples)
list or array-like of nested sampling runs.
List or array-like of one or more nested sampling runs.
If only a single run is provided, this recalculates the live points and
as such can be used for masked runs.

Returns
-------
Expand Down
43 changes: 43 additions & 0 deletions tests/test_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,49 @@ def test_posterior_points():
assert_array_equal(ns.posterior_points(0.5), ns.posterior_points(0.5))


def test_importance_samples():
np.random.seed(3)
ns0 = NestedSamples(root='./tests/example_data/pc')
pi0 = ns0.set_beta(0)
NS0 = ns0.ns_output(nsamples=2000)

with pytest.raises(NotImplementedError):
ns0.importance_sample(ns0.logL, action='spam')
williamjameshandley marked this conversation as resolved.
Show resolved Hide resolved

ns_masked = ns0.importance_sample(ns0.logL, action='replace')
assert_array_equal(ns0.logL, ns_masked.logL)
assert_array_equal(ns0.logL_birth, ns_masked.logL_birth)
assert_array_equal(ns0.weights, ns_masked.weights)

ns_masked = ns0.importance_sample(np.zeros_like(ns0.logL), action='add')
assert_array_equal(ns0.logL, ns_masked.logL)
assert_array_equal(ns0.logL_birth, ns_masked.logL_birth)
assert_array_equal(ns0.weights, ns_masked.weights)

mask = ((ns0.x0 > -0.3) & (ns0.x2 > 0.2) & (ns0.x4 < 3.5)).to_numpy()
ns_masked = merge_nested_samples((ns0[mask], ))
V_prior = pi0[mask].weights.sum() / pi0.weights.sum()
V_posterior = ns0[mask].weights.sum() / ns0.weights.sum()

ns1 = ns0.importance_sample(mask, action='mask')
assert_array_equal(ns_masked.logL, ns1.logL)
assert_array_equal(ns_masked.logL_birth, ns1.logL_birth)
assert_array_equal(ns_masked.weights, ns1.weights)

logL_new = np.where(mask, 0, -np.inf)
ns1 = ns0.importance_sample(logL_new=logL_new)
NS1 = ns1.ns_output(nsamples=2000)
assert_array_equal(ns1, ns_masked)
logZ_V = NS0.logZ.mean() + np.log(V_posterior) - np.log(V_prior)
assert abs(NS1.logZ.mean() - logZ_V) < 1.5 * NS1.logZ.std()

logL_new = np.where(mask, 0, -1e30)
ns1 = ns0.importance_sample(logL_new=logL_new)
NS1 = ns1.ns_output(nsamples=2000)
logZ_V = NS0.logZ.mean() + np.log(V_posterior)
assert abs(NS1.logZ.mean() - logZ_V) < 1.5 * NS1.logZ.std()


def test_wedding_cake():
np.random.seed(3)
wc = WeddingCake(4, 0.5, 0.01)
Expand Down