Skip to content

Commit

Permalink
Merge pull request #628 from ChrisCummins/feature/observation-space-s…
Browse files Browse the repository at this point in the history
…pec-errors

[core] Better error message is proto is invalid.
  • Loading branch information
ChrisCummins authored Mar 17, 2022
2 parents 6bf0209 + b4217a9 commit d22fd36
Showing 1 changed file with 26 additions and 3 deletions.
29 changes: 26 additions & 3 deletions compiler_gym/views/observation_space_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable, ClassVar, Optional, Union

# import networkx as nx
# import numpy as np
from gym.spaces import Space

from compiler_gym.service.proto import Event, ObservationSpace, py_converters
from compiler_gym.util.gym_type_hints import ObservationType
from compiler_gym.util.shell_format import indent


class ObservationSpaceSpec:
Expand Down Expand Up @@ -95,10 +94,34 @@ def __eq__(self, rhs) -> bool:

@classmethod
def from_proto(cls, index: int, proto: ObservationSpace):
"""Create an observation space from a ObservationSpace protocol buffer.
:param index: The index of this observation space into the list of
observation spaces that the compiler service supports.
:param proto: An ObservationSpace protocol buffer.
:raises ValueError: If protocol buffer is invalid.
"""
try:
spec = ObservationSpaceSpec.message_converter(proto.space)
except ValueError as e:
raise ValueError(
f"Error interpreting description of observation space '{proto.name}'.\n"
f"Error: {e}\n"
f"ObservationSpace message:\n"
f"{indent(proto.space, n=2)}"
) from e

# TODO(cummins): Additional validation of the observation space
# specification would be useful here, such as making sure that the size
# of {low, high} tensors for box shapes match. At present, these errors
# tend not to show up until later, making it more difficult to debug.

return cls(
id=proto.name,
index=index,
space=ObservationSpaceSpec.message_converter(proto.space),
space=spec,
translate=ObservationSpaceSpec.message_converter,
to_string=str,
deterministic=proto.deterministic,
Expand Down

0 comments on commit d22fd36

Please sign in to comment.