Skip to content

Commit

Permalink
Adding shift parameter to fix #72
Browse files Browse the repository at this point in the history
  • Loading branch information
Samreay committed Sep 25, 2019
1 parent f5da4ab commit 53ef426
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 3 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ post, it can be solved by explicitly install the `matplotlib` dependency `dvipng

### Update History

##### 0.30.0
* Bug fix for specifying numeric `loc` to `legend_kwargs`
* Added `shift_params` when adding chains.

##### 0.29.1
* Potential bug fix for `log_space` feature.

Expand Down
10 changes: 10 additions & 0 deletions chainconsumer/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
marker_size=None,
marker_alpha=None,
zorder=None,
shift_params=None,
):
self.chain = chain
self.parameters = parameters
Expand All @@ -63,6 +64,15 @@ def __init__(
for i, p in enumerate(parameters):
self.posterior_max_params[p] = chain[self.posterior_max_index, i]

self.shift_params = shift_params
if shift_params is not None:
for key in shift_params.keys():
try:
index = self.parameters.index(key)
avg = np.average(chain[:, index], weights=weights)
chain[:, index] += shift_params[key] - avg
except ValueError:
continue
self.weights = weights
self.posterior = posterior
self.walkers = walkers
Expand Down
16 changes: 14 additions & 2 deletions chainconsumer/chainconsumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class ChainConsumer(object):
"""

__version__ = "0.29.1"
__version__ = "0.30.0"

def __init__(self):
logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -86,6 +86,7 @@ def add_chain(
cmap=None,
num_cloud=None,
zorder=None,
shift_params=None,
):
r""" Add a chain to the consumer.
Expand Down Expand Up @@ -184,7 +185,9 @@ def add_chain(
to colour scatter. Defaults to 15k per chain.
zorder : int, optional
The zorder to pass to `matplotlib` when plotting to determine visual order in the plot.
shift_params : dict|list, optional
Shifts the parameters specify to the numeric values. Useful to shift contours to the same location to perform blinded
uncertainty comparisons.
Returns
-------
ChainConsumer
Expand Down Expand Up @@ -251,6 +254,13 @@ def add_chain(
if p not in self._all_parameters:
self._all_parameters.append(p)

if shift_params is not None:
if isinstance(shift_params, list):
shift_params = dict([(p, s) for p, s in zip(parameters, shift_params)])
for key in shift_params.keys():
if key not in parameters:
self._logger.warning("Warning, shift parameter %s is not in list of parameters %s" % (key, parameters))

# Sorry, no KDE for you on a grid.
if grid:
kde = None
Expand Down Expand Up @@ -290,6 +300,7 @@ def add_chain(
cmap=cmap,
num_cloud=num_cloud,
zorder=zorder,
shift_params=shift_params,
)
self.chains.append(c)
self._init_params()
Expand Down Expand Up @@ -443,6 +454,7 @@ def configure(
watermark_text_kwargs=None,
summary_area=0.6827,
zorder=None,
stack=False,
): # pragma: no cover
r""" Configure the general plotting parameters common across the bar
and contour plots.
Expand Down
2 changes: 1 addition & 1 deletion chainconsumer/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def plot(
text.set_color(c)
if not outside:
loc = legend_kwargs.get("loc") or ""
if "right" in loc.lower():
if isinstance(loc, str) and "right" in loc.lower():
vp = leg._legend_box._children[-1]._children[0]
vp.align = "right"

Expand Down
46 changes: 46 additions & 0 deletions examples/customisations/plot_shift.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
"""
==============
Shifting Plots
==============
Shift all your plots to the same location for blind uncertainty comparison.
Plots will shift to the location you tell them to, in the same format as a truth dictionary.
So you can use truth dict for both! Takes a list or a dict as input for convenience.
"""

import numpy as np
from numpy.random import multivariate_normal
from chainconsumer import ChainConsumer

np.random.seed(0)
data1 = multivariate_normal([1, 0], [[3, 2], [2, 3]], size=300000)
data2 = multivariate_normal([0, 0.5], [[1, -0.7], [-0.7, 1]], size=300000)
data3 = multivariate_normal([2, -1], [[0.5, 0], [0, 0.5]], size=300000)

###############################################################################
# And this is how easy it is to shift them:

truth = {"$x$": 1, "$y$": 0}
c = ChainConsumer()
c.add_chain(data1, parameters=["$x$", "$y$"], name="Chain A", shift_params=truth)
c.add_chain(data2, name="Chain B", shift_params=truth)
c.add_chain(data3, name="Chain C", shift_params=truth)
fig = c.plotter.plot(truth=truth)

fig.set_size_inches(2.5 + fig.get_size_inches()) # Resize fig for doco. You don't need this.

###############################################################################
# Here's without the shift:

truth = {"$x$": 1, "$y$": 0}
c = ChainConsumer()
c.add_chain(data1, parameters=["$x$", "$y$"], name="Chain A")
c.add_chain(data2, name="Chain B")
c.add_chain(data3, name="Chain C")
fig = c.plotter.plot(truth=truth)

fig.set_size_inches(2.5 + fig.get_size_inches()) # Resize fig for doco. You don't need this.

0 comments on commit 53ef426

Please sign in to comment.