diff --git a/WORKSPACE b/WORKSPACE index 941d88963..5993dc07d 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -277,9 +277,12 @@ http_archive( http_archive( name = "csmith", build_file_content = all_content, - sha256 = "ba871c1e5a05a71ecd1af514fedba30561b16ee80b8dd5ba8f884eaded47009f", + sha256 = "86d08c19a1f123054ed9350b7962edaad7d46612de0e07c10b73e578911156fd", strip_prefix = "csmith-csmith-2.3.0", - urls = ["https://github.com/csmith-project/csmith/archive/refs/tags/csmith-2.3.0.tar.gz"], + urls = [ + "https://github.com/ChrisCummins/csmith/archive/refs/tags/csmith-2.3.0.tar.gz", + "https://github.com/csmith-project/csmith/archive/refs/tags/csmith-2.3.0.tar.gz", + ], ) # === DeepDataFlow === diff --git a/examples/example_compiler_gym_service/__init__.py b/examples/example_compiler_gym_service/__init__.py index 4d8965334..8cb7ab614 100644 --- a/examples/example_compiler_gym_service/__init__.py +++ b/examples/example_compiler_gym_service/__init__.py @@ -74,7 +74,7 @@ def benchmark_uris(self) -> Iterable[str]: def benchmark_from_parsed_uri(self, uri: BenchmarkUri) -> Benchmark: if uri.path in self._benchmarks: - return self._benchmarks[uri] + return self._benchmarks[uri.path] else: raise LookupError("Unknown program name") diff --git a/examples/loop_optimizations_service/__init__.py b/examples/loop_optimizations_service/__init__.py index 45ab1055a..b6866f07b 100644 --- a/examples/loop_optimizations_service/__init__.py +++ b/examples/loop_optimizations_service/__init__.py @@ -8,6 +8,7 @@ from typing import Iterable from compiler_gym.datasets import Benchmark, Dataset +from compiler_gym.datasets.uri import BenchmarkUri from compiler_gym.envs.llvm.llvm_benchmark import get_system_library_flags from compiler_gym.spaces import Reward from compiler_gym.third_party import llvm @@ -86,15 +87,15 @@ def __init__(self, *args, **kwargs): ) self._benchmarks = { - "benchmark://loops-opt-v0/add": Benchmark.from_file_contents( + "/add": Benchmark.from_file_contents( "benchmark://loops-opt-v0/add", self.preprocess(BENCHMARKS_PATH / "add.c"), ), - "benchmark://loops-opt-v0/offsets1": Benchmark.from_file_contents( + "/offsets1": Benchmark.from_file_contents( "benchmark://loops-opt-v0/offsets1", self.preprocess(BENCHMARKS_PATH / "offsets1.c"), ), - "benchmark://loops-opt-v0/conv2d": Benchmark.from_file_contents( + "/conv2d": Benchmark.from_file_contents( "benchmark://loops-opt-v0/conv2d", self.preprocess(BENCHMARKS_PATH / "conv2d.c"), ), @@ -122,11 +123,11 @@ def preprocess(src: Path) -> bytes: ) def benchmark_uris(self) -> Iterable[str]: - yield from self._benchmarks.keys() + yield from (f"benchmark://loops-opt-v0{k}" for k in self._benchmarks.keys()) - def benchmark(self, uri: str) -> Benchmark: - if uri in self._benchmarks: - return self._benchmarks[uri] + def benchmark_from_parsed_uri(self, uri: BenchmarkUri) -> Benchmark: + if uri.path in self._benchmarks: + return self._benchmarks[uri.path] else: raise LookupError("Unknown program name") diff --git a/tests/llvm/BUILD b/tests/llvm/BUILD index 32280ec76..c94d4fa98 100644 --- a/tests/llvm/BUILD +++ b/tests/llvm/BUILD @@ -177,6 +177,7 @@ py_test( "//compiler_gym/envs", "//compiler_gym/service/proto", "//tests:test_main", + "//tests/pytest_plugins:common", "//tests/pytest_plugins:llvm", ], ) @@ -232,6 +233,7 @@ py_test( deps = [ "//compiler_gym/envs", "//tests:test_main", + "//tests/pytest_plugins:common", "//tests/pytest_plugins:llvm", ], ) diff --git a/tests/llvm/CMakeLists.txt b/tests/llvm/CMakeLists.txt index bfe26234d..c541bd091 100644 --- a/tests/llvm/CMakeLists.txt +++ b/tests/llvm/CMakeLists.txt @@ -173,6 +173,7 @@ cg_py_test( DEPS compiler_gym::envs::envs compiler_gym::service::proto::proto + tests::pytest_plugins::common tests::pytest_plugins::llvm tests::test_main ) @@ -230,6 +231,7 @@ cg_py_test( "observation_spaces_test.py" DEPS compiler_gym::envs::envs + tests::pytest_plugins::common tests::pytest_plugins::llvm tests::test_main ) diff --git a/tests/llvm/gym_interface_compatability.py b/tests/llvm/gym_interface_compatability.py index 7b385a01b..ac22b164c 100644 --- a/tests/llvm/gym_interface_compatability.py +++ b/tests/llvm/gym_interface_compatability.py @@ -12,6 +12,10 @@ def test_type_classes(env: LlvmEnv): + env.observation_space = "Autophase" + env.reward_space = "IrInstructionCount" + env.reset() + assert isinstance(env, gym.Env) assert isinstance(env, LlvmEnv) assert isinstance(env.unwrapped, LlvmEnv)