Skip to content

Commit

Permalink
The union type checker guarantees finding the most specific relevant …
Browse files Browse the repository at this point in the history
…type.

The original version checked the found type against the registered types, which wasn't correct because a superclass could be subscripted.
For example, if `Sequence[int]` is registered then `[0]` must resolve to `list[int]` not `list`.
Literals can also have a Union base. Fixes # 53.
  • Loading branch information
coady committed Sep 18, 2021
1 parent 59e2b59 commit fd95a14
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
13 changes: 5 additions & 8 deletions multimethod/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def __new__(cls, tp, *args):
if set(args) <= {object} and not (origin is tuple and args):
return origin
bases = (origin,) if type(origin) is type else ()
if origin is Literal and len(args) == 1:
bases = tuple(map(type, args))
if origin is Literal:
bases = (subtype(Union[tuple(map(type, args))]),)
namespace = {'__origin__': origin, '__args__': args}
return type.__new__(cls, str(tp), bases, namespace)

Expand Down Expand Up @@ -126,12 +126,9 @@ def get_type(self, arg) -> type:
if any(arg == param and type(arg) is type(param) for param in self.__args__):
return subtype(Literal, arg)
return type(arg)
if self.__origin__ is Union:
cls = subtype.get_type(self.__args__[0], arg)
for tp_arg in self.__args__[1:]:
if issubclass(tp_arg, cls): # find the most specific match without duplication
cls = subtype.get_type(tp_arg, arg)
return cls
if self.__origin__ is Union: # find the most specific match
tps = (subtype.get_type(tp_arg, arg) for tp_arg in self.__args__)
return functools.reduce(lambda l, r: l if issubclass(l, r) else r, tps)
if not isinstance(arg, self.__origin__): # no need to check subscripts
return type(arg)
if isinstance(arg, Iterator) or not isinstance(arg, Iterable):
Expand Down
4 changes: 4 additions & 0 deletions tests/test_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def test_subtype():
assert tp.get_type([0]) == List[int]
assert tp.get_type([[]]) == List[subtype(list, Empty)]
assert tp.get_type([[0]]) == List[List[int]]
tp = subtype(Union, int, Iterable[int], list)
assert tp.get_type([0]) == List[int]
assert tp.get_type([]) == subtype(list, Empty)


def test_signature():
Expand Down Expand Up @@ -299,6 +302,7 @@ def temp(x: bool, y=0.0):
def test_literals():
from typing import Literal

assert issubclass(subtype(Literal['a', 'b']), str)
tp = subtype(Literal['a', 0])
assert issubclass(tp.get_type('a'), tp)
assert issubclass(tp.get_type(0), tp)
Expand Down

0 comments on commit fd95a14

Please sign in to comment.