Commit
Add mean/std/IQM and confidence intervals to the markdown summary script.
Showing 1 changed file with 124 additions and 44 deletions.
@@ -1,77 +1,157 @@
"""Generate a markdown summary of the results of a benchmarking run."""
import argparse
import pathlib
import sys
from collections import Counter
from functools import lru_cache
from typing import Generator, Sequence, cast

import datasets
import numpy as np
from huggingface_sb3 import EnvironmentName
from rliable import library as rly
from rliable import metrics

from imitation.data import rollout, types
from imitation.data.huggingface_utils import TrajectoryDatasetSequence
from imitation.util.sacred_file_parsing import (
    find_sacred_runs,
    group_runs_by_algo_and_env,
)


def print_markdown_summary(path: pathlib.Path):
@lru_cache(maxsize=None)
def get_random_agent_score(env: str):
    stats = rollout.rollout_stats(
        cast(
            Sequence[types.TrajectoryWithRew],
            TrajectoryDatasetSequence(
                datasets.load_dataset(
                    f"HumanCompatibleAI/random-{EnvironmentName(env)}",
                )["train"],
            ),
        ),
    )
    return stats["monitor_return_mean"]


def print_markdown_summary(path: pathlib.Path) -> Generator[str, None, None]:
    if not path.exists():
        raise NotADirectoryError(f"Path {path} does not exist.")

    print("# Benchmark Summary")
    yield "# Benchmark Summary"
    yield ""
    yield (
        f"This is a summary of the sacred runs in `{path}` generated by "
        f"`sacred_output_to_markdown_summary.py`."
    )

    runs_by_algo_and_env = group_runs_by_algo_and_env(path)
    algos = sorted(runs_by_algo_and_env.keys())

    print("## Run status" "")
    print("Status | Count")
    print("--- | ---")
    status_counts = Counter((run["status"] for _, run in find_sacred_runs(path)))
    statuses = sorted(list(status_counts))
    for status in statuses:
        print(f"{status} | {status_counts[status]}")
    print()
    # Note: we only print the status section if there are multiple statuses
    if not (len(statuses) == 1 and statuses[0] == "COMPLETED"):
        yield "## Run status" ""
        yield "Status | Count"
        yield "--- | ---"
        for status in statuses:
            yield f"{status} | {status_counts[status]}"
        yield ""

yield "## Detailed Run Status" | ||
yield f"Algorithm | Environment | {' | '.join(statuses)}" | ||
yield "--- | --- " + " | --- " * len(statuses) | ||
for algo in algos: | ||
envs = sorted(runs_by_algo_and_env[algo].keys()) | ||
for env in envs: | ||
status_counts = Counter( | ||
(run["status"] for run in runs_by_algo_and_env[algo][env]), | ||
) | ||
yield ( | ||
f"{algo} | {env} | " | ||
f"{' | '.join([str(status_counts[status]) for status in statuses])}" | ||
) | ||
|
||
yield "## Scores" | ||
yield "" | ||
yield ( | ||
"The scores are normalized based on the performance of a random agent as the" | ||
" baseline and the expert as the maximum possible score as explained " | ||
"[in this blog post](https://araffin.github.io/post/rliable/):" | ||
) | ||
yield "> `(score - random_score) / (expert_score - random_score)`" | ||
yield "" | ||
yield ( | ||
"Aggregate scores and confidence intervals are computed using the " | ||
"[rliable library](https://agarwl.github.io/rliable/)." | ||
) | ||
|
||
print("## Detailed Run Status") | ||
print(f"Algorithm | Environment | {' | '.join(sorted(list(status_counts)))}") | ||
print("--- | --- " + " | --- " * len(statuses)) | ||
for algo in algos: | ||
envs = sorted(runs_by_algo_and_env[algo].keys()) | ||
for env in envs: | ||
status_counts = Counter( | ||
(run["status"] for run in runs_by_algo_and_env[algo][env]), | ||
) | ||
print( | ||
f"{algo} | {env} | " | ||
f"{' | '.join([str(status_counts[status]) for status in statuses])}", | ||
) | ||
print() | ||
print("## Raw Scores") | ||
print() | ||
    for algo in algos:
        print(f"### {algo.upper()}")
        print("Environment | Scores | Expert Scores")
        print("--- | --- | ---")
        yield f"### {algo.upper()}"
        yield "Environment | Score (mean/std)| Normalized Score (mean/std) | N"
        yield " --- | --- | --- | --- "
        envs = sorted(runs_by_algo_and_env[algo].keys())
        accumulated_normalized_scores = []
        for env in envs:
            completed_runs = [
                run
                for run in runs_by_algo_and_env[algo][env]
                if run["status"] == "COMPLETED"
            ]
            algo_scores = [
            scores = [
                run["result"]["imit_stats"]["monitor_return_mean"]
                for run in completed_runs
                for run in runs_by_algo_and_env[algo][env]
            ]
            expert_scores = [
                run["result"]["expert_stats"]["monitor_return_mean"]
                for run in completed_runs
                for run in runs_by_algo_and_env[algo][env]
            ]
            print(
            random_score = get_random_agent_score(env)
            normalized_score = [
                (score - random_score) / (expert_score - random_score)
                for score, expert_score in zip(scores, expert_scores)
            ]
            accumulated_normalized_scores.append(normalized_score)

            yield (
                f"{env} | "
                f"{', '.join([f'{score:.2f}' for score in algo_scores])} | "
                f"{', '.join([f'{score:.2f}' for score in expert_scores])}",
                f"{np.mean(scores):.3f} / {np.std(scores):.3f} | "
                f"{np.mean(normalized_score):.3f} / {np.std(normalized_score):.3f} | "
                f"{len(scores)}"
            )
        print()

        aggregate_scores, aggregate_score_cis = rly.get_interval_estimates(
            {"normalized_score": np.asarray(accumulated_normalized_scores).T},
            lambda x: np.array([metrics.aggregate_mean(x), metrics.aggregate_iqm(x)]),
            reps=1000,
        )
        yield ""
        yield "#### Aggregate Normalized scores"

        yield "Metric | Value | 95% CI"
        yield " --- | --- | --- "
        yield (
            f"Mean | "
            f"{aggregate_scores['normalized_score'][0]:.3f} | "
            f"[{aggregate_score_cis['normalized_score'][0][0]:.3f}, "
            f"{aggregate_score_cis['normalized_score'][0][1]:.3f}]"
        )
        yield (
            f"IQM | "
            f"{aggregate_scores['normalized_score'][1]:.3f} | "
            f"[{aggregate_score_cis['normalized_score'][1][0]:.3f}, "
            f"{aggregate_score_cis['normalized_score'][1][1]:.3f}]"
        )
        yield ""


if __name__ == "__main__":
    if len(sys.argv) != 2:
        print(f"Usage: {sys.argv[0]} <path to sacred run folder>")
        sys.exit(1)
    parser = argparse.ArgumentParser(
        description="Generate a markdown summary of the results of a benchmarking run.",
    )
    parser.add_argument("path", type=pathlib.Path)
    parser.add_argument("--output", type=pathlib.Path, default="summary.md")

    args = parser.parse_args()

    print_markdown_summary(pathlib.Path(sys.argv[1]))
    with open(args.output, "w") as fh:
        for line in print_markdown_summary(pathlib.Path(args.path)):
            fh.write(line)
            fh.write("\n")
            fh.flush()
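For context, the aggregation step added by this commit can be exercised on its own. The following is a minimal sketch, not part of the diff: it feeds made-up normalized scores (rows are runs, columns are environments, matching the shape of np.asarray(accumulated_normalized_scores).T above) through the same rliable call to get the mean and IQM with bootstrap confidence intervals.

# Toy sketch of the new aggregation step; the scores below are fabricated.
import numpy as np
from rliable import library as rly
from rliable import metrics

# Fake normalized returns for one algorithm: 5 runs (seeds) x 3 environments.
rng = np.random.default_rng(0)
normalized_scores = rng.uniform(0.4, 1.1, size=(5, 3))

aggregate_scores, aggregate_score_cis = rly.get_interval_estimates(
    {"normalized_score": normalized_scores},
    lambda x: np.array([metrics.aggregate_mean(x), metrics.aggregate_iqm(x)]),
    reps=1000,  # number of bootstrap resamples, as in the script above
)

mean_score, iqm_score = aggregate_scores["normalized_score"]
print(f"Mean: {mean_score:.3f}, IQM: {iqm_score:.3f}")
# Bootstrap 95% confidence-interval bounds for both metrics (see the rliable
# documentation for the exact layout of this array).
print(aggregate_score_cis["normalized_score"])

With the new argparse entry point, the summary itself would be produced with something like "python sacred_output_to_markdown_summary.py <path to sacred run folder> --output summary.md"; the positional path argument and the summary.md default come from the arguments added above.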