Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] Better error message is proto is invalid. #628

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions compiler_gym/views/observation_space_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable, ClassVar, Optional, Union

# import networkx as nx
# import numpy as np
from gym.spaces import Space

from compiler_gym.service.proto import Event, ObservationSpace, py_converters
from compiler_gym.util.gym_type_hints import ObservationType
from compiler_gym.util.shell_format import indent


class ObservationSpaceSpec:
Expand Down Expand Up @@ -95,10 +94,34 @@ def __eq__(self, rhs) -> bool:

@classmethod
def from_proto(cls, index: int, proto: ObservationSpace):
"""Create an observation space from a ObservationSpace protocol buffer.

:param index: The index of this observation space into the list of
observation spaces that the compiler service supports.

:param proto: An ObservationSpace protocol buffer.

:raises ValueError: If protocol buffer is invalid.
"""
try:
spec = ObservationSpaceSpec.message_converter(proto.space)
except ValueError as e:
raise ValueError(
f"Error interpreting description of observation space '{proto.name}'.\n"
f"Error: {e}\n"
f"ObservationSpace message:\n"
f"{indent(proto.space, n=2)}"
) from e

# TODO(cummins): Additional validation of the observation space
# specification would be useful here, such as making sure that the size
# of {low, high} tensors for box shapes match. At present, these errors
# tend not to show up until later, making it more difficult to debug.

return cls(
id=proto.name,
index=index,
space=ObservationSpaceSpec.message_converter(proto.space),
space=spec,
translate=ObservationSpaceSpec.message_converter,
to_string=str,
deterministic=proto.deterministic,
Expand Down