Skip to content

Commit

Permalink
Merge pull request #296 from mj-will/update-state-plot
Browse files Browse the repository at this point in the history
Improve state plot
  • Loading branch information
mj-will authored Apr 3, 2023
2 parents 77a6248 + 9859e8c commit 8e75c84
Showing 1 changed file with 21 additions and 29 deletions.
50 changes: 21 additions & 29 deletions nessai/samplers/nestedsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,35 +943,26 @@ def plot_state(self, filename=None):
returned.
"""

fig, ax = plt.subplots(6, 1, sharex=True, figsize=(12, 12))
fig, ax = plt.subplots(7, 1, sharex=True, figsize=(12, 12))
ax = ax.ravel()
it = (np.arange(len(self.min_likelihood))) * (self.nlive // 10)
it[-1] = self.iteration

for t in self.training_iterations:
for a in ax:
a.axvline(t, ls="-", color="lightgrey")

if not self.train_on_empty:
for p in self.population_iterations:
for a in ax:
a.axvline(p, ls="-", color="tab:orange")

for i in self.checkpoint_iterations:
for a in ax:
a.axvline(i, ls=":", color="#66ccff")

for a in ax:
a.axvline(self.iteration, c="#ff9900", ls="-.")

ax[0].plot(it, self.min_likelihood, label="Min logL")
ax[0].plot(it, self.max_likelihood, label="Max logL")
ax[0].set_ylabel("logL")
ax[0].plot(it, self.min_likelihood, label="Min log L")
ax[0].plot(it, self.max_likelihood, label="Max log L")
ax[0].set_ylabel(r"$\log L$")
ax[0].legend(frameon=False)

logX_its = np.arange(len(self.state.log_vols))
ax[1].plot(logX_its, self.state.log_vols, label="log X")
ax[1].set_ylabel("Log X")
ax[1].set_ylabel(r"$\log X$")
ax[1].legend(frameon=False)

if self.state.track_gradients:
Expand All @@ -993,10 +984,10 @@ def plot_state(self, filename=None):
)

ax[2].plot(it, self.likelihood_evaluations, label="Evaluations")
ax[2].set_ylabel("logL evaluations")
ax[2].set_ylabel("Likelihood\n evaluations")

ax[3].plot(it, self.logZ_history, label="logZ")
ax[3].set_ylabel("logZ")
ax[3].set_ylabel(r"$\log Z$")
ax[3].legend(frameon=False)

ax_dz = plt.twinx(ax[3])
Expand All @@ -1007,7 +998,7 @@ def plot_state(self, filename=None):
c="C1",
ls=config.plotting.line_styles[1],
)
ax_dz.set_ylabel("dZ")
ax_dz.set_ylabel(r"$dZ$")
handles, labels = ax[3].get_legend_handles_labels()
handles_dz, labels_dz = ax_dz.get_legend_handles_labels()
ax[3].legend(handles + handles_dz, labels + labels_dz, frameon=False)
Expand All @@ -1034,11 +1025,21 @@ def plot_state(self, filename=None):
handles_r, labels_r = ax_r.get_legend_handles_labels()
ax[4].legend(handles + handles_r, labels + labels_r, frameon=False)

dtrain = np.array(self.training_iterations[1:]) - np.array(
self.training_iterations[:-1]
)
ax[5].plot(self.training_iterations[1:], dtrain)
if self.training_iterations:
ax[5].axvline(
self.training_iterations[0], ls="-", color="lightgrey"
)
ax[5].set_ylabel(r"$\Delta$ train")

if len(self.rolling_p):
it = (np.arange(len(self.rolling_p)) + 1) * self.nlive
ax[5].plot(it, self.rolling_p, "o", label="p-value")
ax[5].set_ylabel("p-value")
ax[5].set_ylim([-0.1, 1.1])
ax[6].plot(it, self.rolling_p, "o", label="p-value")
ax[6].set_ylabel("p-value")
ax[6].set_ylim([-0.1, 1.1])

ax[-1].set_xlabel("Iteration")

Expand All @@ -1054,15 +1055,6 @@ def plot_state(self, filename=None):
linestyle="-.",
label="Current iteration",
),
Line2D(
[0],
[0],
color="lightgrey",
linestyle="-",
markersize=10,
markeredgewidth=1.5,
label="Training",
),
Line2D(
[0], [0], color="#66ccff", linestyle=":", label="Checkpoint"
),
Expand Down

0 comments on commit 8e75c84

Please sign in to comment.