Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added new mypy tags and fixed all new mypy errors #117

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ default:

SOURCE_DIR_NAME=vistautils

MYPY:=mypy $(SOURCE_DIR_NAME) tests
MYPY:=mypy $(SOURCE_DIR_NAME) tests --disallow-any-generics --disallow-any-unimported
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be adding these to mypy.ini.


# Suppressed warnings:
# Too many arguments, Unexpected keyword arguments: can't do static analysis on attrs __init__
Expand Down
5 changes: 3 additions & 2 deletions tests/test_range.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
from typing import Any
from unittest import TestCase

from immutablecollections import ImmutableSet
Expand Down Expand Up @@ -233,15 +234,15 @@ def check_contains(self, rng: Range[int]) -> None:
self.assertTrue(7 in rng)
self.assertFalse(8 in rng)

def assert_unbounded_below(self, rng: Range):
def assert_unbounded_below(self, rng: Range[Any]):
self.assertFalse(rng.has_lower_bound())
with self.assertRaises(ValueError):
rng.lower_endpoint()
# pylint: disable=pointless-statement
with self.assertRaises(AssertionError):
rng.lower_bound_type

def assert_unbounded_above(self, rng: Range):
def assert_unbounded_above(self, rng: Range[Any]):
self.assertFalse(rng.has_upper_bound())
with self.assertRaises(ValueError):
rng.upper_endpoint()
Expand Down
20 changes: 11 additions & 9 deletions vistautils/attrutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import vistautils.preconditions


def attrib_instance_of(type_: Union[Type, Tuple[Type, ...]], *args, **kwargs):
def attrib_instance_of(type_: Union[Type[Any], Tuple[Type[Any], ...]], *args, **kwargs):
warnings.warn(
"Deprecated, use attrib(validator=instance_of(type))", DeprecationWarning
)
Expand All @@ -22,7 +22,7 @@ def attrib_instance_of(type_: Union[Type, Tuple[Type, ...]], *args, **kwargs):
# TODO cannot currently be used with additional validators:
# https:/isi-nlp/isi-flexnlp/issues/188
def attrib_opt_instance_of(
type_: Union[Type, Tuple[Type, ...]], *args, default=None, **kwargs
type_: Union[Type[Any], Tuple[Type[Any], ...]], *args, default=None, **kwargs
):
warnings.warn(
"Deprecated, use attrib(default=<default>, validator=optional(instance_of(<type>)))",
Expand All @@ -34,14 +34,14 @@ def attrib_opt_instance_of(
)


def attrib_factory(factory: Callable, *args, **kwargs):
def attrib_factory(factory: Callable[[Any], Any], *args, **kwargs):
warnings.warn("Deprecated, use attrib(factory=<factory>)", DeprecationWarning)
# Mypy does not understand these arguments
return attrib(default=Factory(factory), *args, **kwargs) # type: ignore


def attrib_immutable(
type_: Type[immutablecollections.ImmutableCollection], *args, **kwargs
type_: Type[immutablecollections.ImmutableCollection[Any]], *args, **kwargs
):
warnings.warn(
"Deprecated, use attrib(converter=<collection factory>)", DeprecationWarning
Expand All @@ -52,7 +52,7 @@ def attrib_immutable(


def attrib_private_immutable_builder(
type_: Type[immutablecollections.ImmutableCollection], *args, **kwargs
type_: Type[immutablecollections.ImmutableCollection[Any]], *args, **kwargs
):
"""
Create an immutable collection builder private attribute.
Expand All @@ -75,7 +75,7 @@ def attrib_private_immutable_builder(
# TODO: The use of Type[ImmutableCollection] causes Mypy warnings
# Perhaps the solution is to make ImmutableCollection a Protocol?
def attrib_opt_immutable(
type_: Type[immutablecollections.ImmutableCollection], *args, **kwargs
type_: Type[immutablecollections.ImmutableCollection[Any]], *args, **kwargs
):
"""Return a attrib with a converter for optional collections.

Expand All @@ -98,7 +98,9 @@ def attrib_opt_immutable(
)


def opt_instance_of(type_: Union[Type, Tuple[Type, ...]]) -> Callable:
def opt_instance_of(
type_: Union[Type[Any], Tuple[Type[Any], ...]]
) -> Callable[[Any], Any]:
warnings.warn("Deprecated, use optional(instance_of(<type>))", DeprecationWarning)
# Mypy does not understand these arguments
return validators.instance_of((type_, type(None))) # type: ignore
Expand All @@ -112,8 +114,8 @@ def _check_immutable_collection(type_):


def _empty_immutable_if_none(
val: Any, type_: Type[immutablecollections.ImmutableCollection]
) -> immutablecollections.ImmutableCollection:
val: Any, type_: Type[immutablecollections.ImmutableCollection[Any]]
) -> immutablecollections.ImmutableCollection[Any]:
if val is None:
return type_.empty()
else:
Expand Down
3 changes: 1 addition & 2 deletions vistautils/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
Any,
AnyStr,
BinaryIO,
Callable,
Iterable,
Iterator,
List,
Expand Down Expand Up @@ -229,7 +228,7 @@ def open(self) -> TextIO:
ret = CharSource.from_string(tgz_data.read().decode(self._encoding)).open()
# we need to fiddle with the close method on the returned TextIO so that when it is
# closed the containing zip file is closed as well
old_close: Callable = ret.close
old_close = ret.close
else:
raise IOError(
f"Could not extract path {self._path_within_tgz} from {self._tgz_path}"
Expand Down
14 changes: 8 additions & 6 deletions vistautils/key_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,9 @@ def zip_bytes_sink(
)


class KeyValueLinearSource(Generic[K, V], AbstractContextManager, metaclass=ABCMeta):
class KeyValueLinearSource(
Generic[K, V], AbstractContextManager, metaclass=ABCMeta # type: ignore
):
"""
Anything which provide a sequence of key-value pairs.

Expand Down Expand Up @@ -875,7 +877,7 @@ def byte_key_value_linear_source_from_params(
params: Parameters,
*,
input_namespace: str = "input",
eval_context: Optional[Dict] = None,
eval_context: Optional[Dict[Any, Any]] = None,
) -> KeyValueLinearSource[str, bytes]:
"""
Get a key-value source based on parameters.
Expand Down Expand Up @@ -918,7 +920,7 @@ def char_key_value_source_from_params(
params: Parameters,
*,
input_namespace: str = "input",
eval_context: Optional[Dict] = None,
eval_context: Optional[Dict[Any, Any]] = None,
) -> KeyValueSource[str, str]:
"""
Get a random-access key-value source based on parameters.
Expand Down Expand Up @@ -957,7 +959,7 @@ def byte_key_value_source_from_params(
params: Parameters,
*,
input_namespace: str = "input",
eval_context: Optional[Dict] = None,
eval_context: Optional[Dict[Any, Any]] = None,
) -> KeyValueSource[str, bytes]:
"""
Get a random-access key-value source based on parameters.
Expand Down Expand Up @@ -1007,7 +1009,7 @@ def char_key_value_sink_from_params(
params: Parameters,
*,
output_namespace: str = "output",
eval_context: Optional[Dict] = None,
eval_context: Optional[Dict[Any, Any]] = None,
) -> KeyValueSink[str, str]:
"""
Get a key-value sink based on parameters.
Expand Down Expand Up @@ -1044,7 +1046,7 @@ def byte_key_value_sink_from_params(
params: Parameters,
*,
output_namespace: str = "output",
eval_context: Optional[Dict] = None,
eval_context: Optional[Dict[Any, Any]] = None,
) -> KeyValueSink[str, bytes]:
"""
Get a binary key-value sink based on parameters.
Expand Down
22 changes: 12 additions & 10 deletions vistautils/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def empty(*, namespace_prefix: Iterable[str] = tuple()) -> "Parameters":

@staticmethod
def from_mapping(
mapping: Mapping, *, namespace_prefix: Iterable[str] = tuple()
mapping: Mapping[Any, Any], *, namespace_prefix: Iterable[str] = tuple()
) -> "Parameters":
"""
Convert a dictionary of dictionaries into a `Parameter`s
Expand Down Expand Up @@ -651,21 +651,23 @@ def namespace_or_empty(self, name: str) -> Optional["Parameters"]:
f"Expected a namespace, but got a regular parameters for {name}"
)

def arbitrary_list(self, name: str, *, default: Optional[List] = None) -> List:
def arbitrary_list(
self, name: str, *, default: Optional[List[Any]] = None
) -> List[Any]:
"""
Get a list with arbitrary structure.
"""
return self.get(name, List, default=default)

@overload
def optional_arbitrary_list(self, name: str) -> Optional[List]:
def optional_arbitrary_list(self, name: str) -> Optional[List[Any]]:
...

@overload
def optional_arbitrary_list(self, name: str, *, default: List) -> List:
def optional_arbitrary_list(self, name: str, *, default: List[Any]) -> List[Any]:
...

def optional_arbitrary_list(self, name: str, *, default: Optional[List] = None):
def optional_arbitrary_list(self, name: str, *, default: Optional[List[Any]] = None):
"""
Get a list with arbitrary structure, if available
"""
Expand Down Expand Up @@ -710,7 +712,7 @@ def evaluate(
name: str,
expected_type: Type[_ParamType],
*,
context: Optional[Mapping] = None,
context: Optional[Mapping[Any, Any]] = None,
namespace_param_name: str = "value",
special_values: Mapping[str, str] = ImmutableDict.empty(),
default: Optional[_ParamType] = None,
Expand Down Expand Up @@ -743,7 +745,7 @@ def handle_special_values(val: str) -> str:
namespace = self.optional_namespace(name)
try:
to_evaluate = None
context_modules: List = []
context_modules: List[Any] = []

if namespace:
to_evaluate = namespace.string(namespace_param_name)
Expand Down Expand Up @@ -773,7 +775,7 @@ def object_from_parameters(
name: str,
expected_type: Type[_ParamType],
*,
context: Optional[Mapping] = None,
context: Optional[Mapping[Any, Any]] = None,
creator_namepace_param_name: str = "value",
special_creator_values: Mapping[str, str] = ImmutableDict.empty(),
default_creator: Optional[Any] = None,
Expand Down Expand Up @@ -1255,7 +1257,7 @@ def _inner_load_from_string(
raise IOError(f"Failure while loading parameter file {error_string}") from e

@staticmethod
def _validate(raw_yaml: Mapping):
def _validate(raw_yaml: Mapping[Any, Any]):
# we don't use check_isinstance so we can have a custom error message
check_arg(
isinstance(raw_yaml, Mapping),
Expand All @@ -1264,7 +1266,7 @@ def _validate(raw_yaml: Mapping):
YAMLParametersLoader._check_all_keys_strings(raw_yaml)

@staticmethod
def _check_all_keys_strings(mapping: Mapping, path=None):
def _check_all_keys_strings(mapping: Mapping[Any, Any], path=None):
if path is None:
path = []

Expand Down
4 changes: 2 additions & 2 deletions vistautils/preconditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# pylint: disable=invalid-name
# Type annotation from TypeShed for classinfo argument of isinstance and issubclass

_ClassInfo = Union[type, Tuple[Union[type, Tuple], ...]]
_ClassInfo = Union[type, Tuple[Union[type, Tuple[Any, ...]], ...]]


T = TypeVar("T")
Expand All @@ -29,7 +29,7 @@ def check_not_none(x: T, msg: str = None) -> T:
return x


def check_arg(result: Any, msg: str = None, msg_args: Tuple = None) -> None:
def check_arg(result: Any, msg: str = None, msg_args: Tuple[Any] = None) -> None:
if not result:
if msg:
raise ValueError(msg % (msg_args or ()))
Expand Down
12 changes: 6 additions & 6 deletions vistautils/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,7 @@ def remove_all(
class _SortedDictRangeSet(RangeSet[T], metaclass=ABCMeta):
# pylint:disable=protected-access

def __init__(self, ranges_by_lower_bound: SortedDict) -> None:
def __init__(self, ranges_by_lower_bound: SortedDict) -> None: # type: ignore
# we store the ranges as a map sorted by their lower bound
# Note that because we enforce that there are no overlapping or connected ranges,
# this sorts the ranges by upper bound as well
Expand Down Expand Up @@ -1039,7 +1039,7 @@ def leftmost_containing_or_above(self, lower_limit: T) -> Optional[Range[T]]:
# an AboveValue cut corresponds to a closed upper interval, which catches containment
# as desired
# I have no idea why mypy is asking for an explicit type assignment here
limit_as_bound: _AboveValue = _AboveValue(lower_limit)
limit_as_bound: _AboveValue[Any] = _AboveValue(lower_limit)

# insertion index into the sorted list of sets
idx = sorted_dict.bisect_left(limit_as_bound)
Expand Down Expand Up @@ -1410,7 +1410,7 @@ def immutablerangemap(


# utility functions for SortedDict to give it an interface more like Java's NavigableMap
def _value_below(sorted_dict: SortedDict, key: T) -> Optional[Any]:
def _value_below(sorted_dict: SortedDict, key: T) -> Optional[Any]: # type: ignore
"""
Get item for greatest key strictly less than the given key

Expand All @@ -1426,7 +1426,7 @@ def _value_below(sorted_dict: SortedDict, key: T) -> Optional[Any]:
return None


def _value_at_or_below(sorted_dict: SortedDict, key: T) -> Optional[Any]:
def _value_at_or_below(sorted_dict: SortedDict, key: T) -> Optional[Any]: # type: ignore
"""
Get item for greatest key less than or equal to a given key.

Expand All @@ -1445,7 +1445,7 @@ def _value_at_or_below(sorted_dict: SortedDict, key: T) -> Optional[Any]:
return sorted_dict[key]


def _value_at_or_above(sorted_dict: SortedDict, key: T) -> Optional[Any]:
def _value_at_or_above(sorted_dict: SortedDict, key: T) -> Optional[Any]: # type: ignore
if not sorted_dict:
return None
idx = sorted_dict.bisect_left(key)
Expand All @@ -1455,7 +1455,7 @@ def _value_at_or_above(sorted_dict: SortedDict, key: T) -> Optional[Any]:
return sorted_dict[sorted_dict.keys()[idx]]


def _clear(
def _clear( # type: ignore
sorted_dict: SortedDict, start_key_inclusive: T, stop_key_exclusive: T
) -> None:
# we copy to a list first in case sorted_dict is not happy with modification during iteration
Expand Down
2 changes: 1 addition & 1 deletion vistautils/scripts/directory_to_key_value_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def key_function_from_params(params: Parameters) -> Callable[[Path], str]:
elif key_function_string == STRIP_ONE_EXTENSION:
return strip_one_extension_key_function
else:
raise NotImplementedError(f"Unknown key function %s", key_function_string)
raise NotImplementedError(f"Unknown key function {key_function_string}")


IDENTITY = "identity"
Expand Down
4 changes: 2 additions & 2 deletions vistautils/span.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, Optional, Sized, Tuple, TypeVar, Union
from typing import Any, Iterable, Optional, Sized, Tuple, TypeVar, Union

from attr import attrib, attrs, validators

Expand Down Expand Up @@ -129,7 +129,7 @@ def __repr__(self):


class HasSpan(Protocol):
__slots__: tuple = ()
__slots__: Tuple[Any, ...] = ()
span: Span

@property
Expand Down