Skip to content

Commit

Permalink
add importance_sample method to NestedSamples and MCMCSamples (#122)
Browse files Browse the repository at this point in the history
* add importance_reweighting method to NestedSamples

* version bump to 2.0.0-beta.3

* change add and replace args to action arg taking a string

* rename importance_reweighting to importance_sample

* append merge_nested_samples docstring to reflect additional utility for passing a single run

* add test for importance_samples

* add action='mask' to importance_sample and corresponding test

* add 'mask' to the sets in docstring and NotImplementedError message

* change action default from 'replace' to 'add'

* fix importance_sample tests

* move importance_sample to MCMCSamples and inherit in NestedSamples

* fix merge_nested_samples docstring: two->one
  • Loading branch information
Lukas Hergt authored Aug 20, 2020
1 parent 9b81fb9 commit 3db7903
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ anesthetic: nested sampling visualisation
=========================================
:anesthetic: nested sampling visualisation
:Author: Will Handley and Lukas Hergt
:Version: 2.0.0-beta.2
:Version: 2.0.0-beta.3
:Homepage: https:/williamjameshandley/anesthetic
:Documentation: http://anesthetic.readthedocs.io/

Expand Down
70 changes: 68 additions & 2 deletions anesthetic/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,40 @@ def plot_2d(self, axes, *args, **kwargs):

return fig, axes

def importance_sample(self, logL_new, action='add'):
"""Perform importance re-weighting on the log-likelihood.
Parameters
----------
logL_new: np.array
New log-likelihood values. Should have the same shape as `logL`.
action: str
Can be any of {'add', 'replace', 'mask'}.
* add: Add the new `logL_new` to the current `logL`.
* replace: Replace the current `logL` with the new `logL_new`.
* mask: treat `logL_new` as a boolean mask and only keep the
corresponding (True) samples.
default: 'add'
Returns
-------
samples: MCMCSamples
Importance re-weighted samples.
"""
samples = self.copy()
if action == 'add':
samples.logL += logL_new
elif action == 'replace':
samples.logL = logL_new
elif action == 'mask':
samples = samples[logL_new]
else:
raise NotImplementedError("`action` needs to be one of "
"{'add', 'replace', 'mask'}, but '%s' "
"was requested." % action)
return samples

def _limits(self, paramname):
limits = self.limits.get(paramname, (None, None))
if limits[0] == limits[1]:
Expand Down Expand Up @@ -611,6 +645,36 @@ def dlogX(self, nsamples=None):
else:
return WeightedDataFrame(dlogX, self.index, weights=self.weights)

def importance_sample(self, logL_new, action='add'):
"""Perform importance re-weighting on the log-likelihood.
Parameters
----------
logL_new: np.array
New log-likelihood values. Should have the same shape as `logL`.
action: str
Can be any of {'add', 'replace', 'mask'}.
* add: Add the new `logL_new` to the current `logL`.
* replace: Replace the current `logL` with the new `logL_new`.
* mask: treat `logL_new` as a boolean mask and only keep the
corresponding (True) samples.
default: 'add'
Returns
-------
samples: NestedSamples
Importance re-weighted samples.
"""
samples = merge_nested_samples((self, ))
samples = super(NestedSamples, samples).importance_sample(
logL_new=logL_new, action=action
)
samples = merge_nested_samples(
(samples[samples.logL > samples.logL_birth], )
)
return samples

def _compute_nlive(self, logL_birth):
if is_int(logL_birth):
nlive = logL_birth
Expand All @@ -637,12 +701,14 @@ def _constructor(self):


def merge_nested_samples(runs):
"""Merge two or more nested sampling runs.
"""Merge one or more nested sampling runs.
Parameters
----------
runs: list(NestedSamples)
list or array-like of nested sampling runs.
List or array-like of one or more nested sampling runs.
If only a single run is provided, this recalculates the live points and
as such can be used for masked runs.
Returns
-------
Expand Down
43 changes: 43 additions & 0 deletions tests/test_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,49 @@ def test_posterior_points():
assert_array_equal(ns.posterior_points(0.5), ns.posterior_points(0.5))


def test_importance_samples():
np.random.seed(3)
ns0 = NestedSamples(root='./tests/example_data/pc')
pi0 = ns0.set_beta(0)
NS0 = ns0.ns_output(nsamples=2000)

with pytest.raises(NotImplementedError):
ns0.importance_sample(ns0.logL, action='spam')

ns_masked = ns0.importance_sample(ns0.logL, action='replace')
assert_array_equal(ns0.logL, ns_masked.logL)
assert_array_equal(ns0.logL_birth, ns_masked.logL_birth)
assert_array_equal(ns0.weights, ns_masked.weights)

ns_masked = ns0.importance_sample(np.zeros_like(ns0.logL), action='add')
assert_array_equal(ns0.logL, ns_masked.logL)
assert_array_equal(ns0.logL_birth, ns_masked.logL_birth)
assert_array_equal(ns0.weights, ns_masked.weights)

mask = ((ns0.x0 > -0.3) & (ns0.x2 > 0.2) & (ns0.x4 < 3.5)).to_numpy()
ns_masked = merge_nested_samples((ns0[mask], ))
V_prior = pi0[mask].weights.sum() / pi0.weights.sum()
V_posterior = ns0[mask].weights.sum() / ns0.weights.sum()

ns1 = ns0.importance_sample(mask, action='mask')
assert_array_equal(ns_masked.logL, ns1.logL)
assert_array_equal(ns_masked.logL_birth, ns1.logL_birth)
assert_array_equal(ns_masked.weights, ns1.weights)

logL_new = np.where(mask, 0, -np.inf)
ns1 = ns0.importance_sample(logL_new=logL_new)
NS1 = ns1.ns_output(nsamples=2000)
assert_array_equal(ns1, ns_masked)
logZ_V = NS0.logZ.mean() + np.log(V_posterior) - np.log(V_prior)
assert abs(NS1.logZ.mean() - logZ_V) < 1.5 * NS1.logZ.std()

logL_new = np.where(mask, 0, -1e30)
ns1 = ns0.importance_sample(logL_new=logL_new)
NS1 = ns1.ns_output(nsamples=2000)
logZ_V = NS0.logZ.mean() + np.log(V_posterior)
assert abs(NS1.logZ.mean() - logZ_V) < 1.5 * NS1.logZ.std()


def test_wedding_cake():
np.random.seed(3)
wc = WeddingCake(4, 0.5, 0.01)
Expand Down

0 comments on commit 3db7903

Please sign in to comment.