Skip to content

Commit

Permalink
make_simplified_union: add caching and reduce allocations
Browse files Browse the repository at this point in the history
make_simplified_union is used in a lot of places and therefore
accounts for a significant share to typechecking time. Based
on sample metrics gathered from a large real-world codebase
we can see that:
 1. the majority of inputs are already as simple as they're
    going to get, which means we can avoid allocation extra
    lists and return the input unchanged
 2. most of the cost of `make_simplified_union` comes from
    `is_proper_subtype`
 3. `is_proper_subtype` has some caching going on under the hood
    but it only applies to `Instance`, and cache hit rate is low
    in this particular case because, as per 1) above, items are
    in fact rarely subtypes of each other

To address 1, refactor `make_simplified_union` with an optimistic
fast path that avoid unnecessary allocations.

To address 2 & 3, introduce a cache to record the result of union
simplification.

These changes are observed to yield significant improvements in
a real-world codebase: a roughly 10-20% overall speedup, with
make_simplified_union/is_proper_subtype no longer showing up as
hotspots in the py-spy profile.

For python#12526
  • Loading branch information
hugues-aff committed Apr 23, 2022
1 parent d1c0616 commit e5a41c6
Showing 1 changed file with 101 additions and 29 deletions.
130 changes: 101 additions & 29 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
since these may assume that MROs are ready.
"""

from typing import cast, Optional, List, Sequence, Set, Iterable, TypeVar, Dict, Tuple, Any, Union
from typing import (
cast, Optional, List, Sequence, Set, Iterable, TypeVar, Dict, Tuple, Any, Union, Callable
)
from typing_extensions import Type as TypingType
import itertools
import sys
Expand Down Expand Up @@ -336,6 +338,47 @@ def is_simple_literal(t: ProperType) -> bool:
return False


def _get_flattened_proper_types(items: Sequence[Type]) -> Sequence[ProperType]:
"""Similar to types.get_proper_types, with flattening of UnionType
Optimized to avoid allocating a new list whenever possible"""
i: int = 0
base: int = 0
n: int = len(items)

# optimistic fast path
while i < n:
t = items[i]
pt = get_proper_type(t)
if id(t) != id(pt) or isinstance(pt, UnionType):
# we need to allocate, switch to slow path
break
# simplify away any number of bottom type at the start of the input
if i == base and i+1 < n and isinstance(pt, UninhabitedType):
base += 1
i += 1

# optimistic fast path reached end of input, no need to allocate
if i == n:
return cast(Sequence[ProperType], items[base:] if base > 0 else items)

all_items = list(cast(Sequence[ProperType], items[base:i]))

while i < n:
pt = get_proper_type(items[i])
if isinstance(pt, UnionType):
all_items.extend(_get_flattened_proper_types(pt.items))
else:
all_items.append(pt)
i += 1
return all_items


_simplified_union_cache: List[Dict[Tuple[ProperType, ...], ProperType]] = [
{} for _ in range(2**3)
]


def make_simplified_union(items: Sequence[Type],
line: int = -1, column: int = -1,
*, keep_erased: bool = False,
Expand All @@ -362,17 +405,35 @@ def make_simplified_union(items: Sequence[Type],
back into a sum type. Set it to False when called by try_expanding_sum_type_
to_union().
"""
items = get_proper_types(items)

# Step 1: expand all nested unions
while any(isinstance(typ, UnionType) for typ in items):
all_items: List[ProperType] = []
for typ in items:
if isinstance(typ, UnionType):
all_items.extend(get_proper_types(typ.items))
else:
all_items.append(typ)
items = all_items
items = _get_flattened_proper_types(items)

cache_fn: Optional[Callable[[ProperType], None]] = None

# 1 or 2 elements account for the vast majority of inputs and are not worth caching:
# - they're two small for the quadratic worst-case cost of simplification to really
# manifest
# - they majority of those inputs are only triggered once
# - avoiding the extra allocations is a bigger win
if len(items) == 1:
return items[0]
elif len(items) > 2:
# NB: ideally we would use a frozenset, but that would require normalizing the
# order of entries in the simplified union, or updating the test harness to
# treat Unions as equivalent regardless of item ordering (which is particularly
# tricky when it comes to all tests using string matching on reveal_type output)
cache_key = tuple(items)
# NB: we need to maintain separate caches depending on flags that might impact
# the results of simplification
cache = _simplified_union_cache[
int(keep_erased)
| int(contract_literals) << 1
| int(state.strict_optional) << 2
]
ret = cache.get(cache_key, None)
if ret is not None:
return ret
cache_fn = lambda v: cache.__setitem__(cache_key, v) # noqa: E731

# Step 2: remove redundant unions
simplified_set = _remove_redundant_union_items(items, keep_erased)
Expand All @@ -381,13 +442,20 @@ def make_simplified_union(items: Sequence[Type],
if contract_literals and sum(isinstance(item, LiteralType) for item in simplified_set) > 1:
simplified_set = try_contracting_literals_in_union(simplified_set)

return UnionType.make_union(simplified_set, line, column)
ret = UnionType.make_union(simplified_set, line, column)

if cache_fn:
cache_fn(ret)

return ret


def _remove_redundant_union_items(items: List[ProperType], keep_erased: bool) -> List[ProperType]:
def _remove_redundant_union_items(items: Sequence[ProperType],
keep_erased: bool) -> Sequence[ProperType]:
from mypy.subtypes import is_proper_subtype

removed: Set[int] = set()
truthed: Set[int] = set()
seen: Set[Tuple[str, ...]] = set()

# NB: having a separate fast path for Union of Literal and slow path for other things
Expand All @@ -397,6 +465,7 @@ def _remove_redundant_union_items(items: List[ProperType], keep_erased: bool) ->
for i, item in enumerate(items):
if i in removed:
continue

# Avoid slow nested for loop for Union of Literal of strings/enums (issue #9169)
k = simple_literal_value_key(item)
if k is not None:
Expand Down Expand Up @@ -434,20 +503,34 @@ def _remove_redundant_union_items(items: List[ProperType], keep_erased: bool) ->
continue
# actual redundancy checks
if (
is_redundant_literal_instance(item, tj) # XXX?
and is_proper_subtype(tj, item, keep_erased_types=keep_erased)
isinstance(tj, UninhabitedType)
or (
(
not isinstance(item, Instance)
or item.last_known_value is None
or (
isinstance(tj, Instance)
and tj.last_known_value == item.last_known_value
)
)
and is_proper_subtype(tj, item, keep_erased_types=keep_erased)
)
):
# We found a redundant item in the union.
removed.add(j)
cbt = cbt or tj.can_be_true
cbf = cbf or tj.can_be_false

# if deleted subtypes had more general truthiness, use that
if not item.can_be_true and cbt:
items[i] = true_or_false(item)
truthed.add(i)
elif not item.can_be_false and cbf:
items[i] = true_or_false(item)
truthed.add(i)

return [items[i] for i in range(len(items)) if i not in removed]
if not removed and not truthed:
return items
return [true_or_false(items[i]) if i in truthed else items[i]
for i in range(len(items)) if i not in removed]


def _get_type_special_method_bool_ret_type(t: Type) -> Optional[Type]:
Expand Down Expand Up @@ -889,17 +972,6 @@ def custom_special_method(typ: Type, name: str, check_all: bool = False) -> bool
return False


def is_redundant_literal_instance(general: ProperType, specific: ProperType) -> bool:
if not isinstance(general, Instance) or general.last_known_value is None:
return True
if isinstance(specific, Instance) and specific.last_known_value == general.last_known_value:
return True
if isinstance(specific, UninhabitedType):
return True

return False


def separate_union_literals(t: UnionType) -> Tuple[Sequence[LiteralType], Sequence[Type]]:
"""Separate literals from other members in a union type."""
literal_items = []
Expand Down

0 comments on commit e5a41c6

Please sign in to comment.