Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Skip waiting for full state if a StateFilter does not require it #12498

Merged
merged 3 commits into from
May 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12498.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Preparation for faster-room-join work: return subsets of room state which we already have, immediately.
63 changes: 59 additions & 4 deletions synapse/storage/state.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,6 +16,7 @@
from typing import (
TYPE_CHECKING,
Awaitable,
Callable,
Collection,
Dict,
Iterable,
Expand Down Expand Up @@ -532,6 +534,44 @@ def approx_difference(self, other: "StateFilter") -> "StateFilter":
new_all, new_excludes, new_wildcards, new_concrete_keys
)

def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool:
    """Decide whether this filter can only be answered from complete room state.

    A filter that is fully satisfied by partial state can be served
    immediately, without first waiting for the full state to arrive.

    Args:
        is_mine_id: predicate that reports whether a given state_key is the
            mxid of a local user.

    Returns:
        True if the caller has to wait for full state before applying this
        filter; False if partial state is sufficient.
    """
    # TODO(faster_joins): it's not entirely clear that this is safe. We might
    #   hand back a piece of state which, after the state has been resynced,
    #   turns out to be invalid -- for instance because its sender was never
    #   actually in the room, in which case it should not have been returned.
    #   We should at least add some tests around this to see what happens.

    # No membership events requested: the answer is whatever 'include_others'
    # says.
    if EventTypes.Member not in self.types:
        return self.include_others

    wanted_members = self.types[EventTypes.Member]

    # None stands for *all* membership events, which forces us to wait.
    if wanted_members is None:
        return True

    # Specific memberships were requested: we only have to wait if at least one
    # of them belongs to a remote user. Purely local lookups can go ahead.
    return any(not is_mine_id(user_id) for user_id in wanted_members)


_ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True)
_ALL_NON_MEMBER_STATE_FILTER = StateFilter(
Expand All @@ -544,6 +584,7 @@ class StateGroupStorage:
"""High level interface to fetching state for event."""

def __init__(self, hs: "HomeServer", stores: "Databases"):
self._is_mine_id = hs.is_mine_id
self.stores = stores
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)

Expand Down Expand Up @@ -675,7 +716,13 @@ async def get_state_for_events(
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
event_to_groups = await self.get_state_group_for_events(event_ids)
await_full_state = True
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
await_full_state = False

event_to_groups = await self.get_state_group_for_events(
event_ids, await_full_state=await_full_state
)

groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
Expand All @@ -699,7 +746,9 @@ async def get_state_for_events(
return {event: event_to_state[event] for event in event_ids}

async def get_state_ids_for_events(
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
self,
event_ids: Collection[str],
state_filter: Optional[StateFilter] = None,
) -> Dict[str, StateMap[str]]:
"""
Get the state dicts corresponding to a list of events, containing the event_ids
Expand All @@ -716,7 +765,13 @@ async def get_state_ids_for_events(
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
event_to_groups = await self.get_state_group_for_events(event_ids)
await_full_state = True
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
await_full_state = False

event_to_groups = await self.get_state_group_for_events(
event_ids, await_full_state=await_full_state
)

groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
Expand Down Expand Up @@ -802,7 +857,7 @@ async def get_state_group_for_events(
Args:
event_ids: events to get state groups for
await_full_state: if true, will block if we do not yet have complete
state at this event.
state at these events.
"""
if await_full_state:
await self._partial_state_events_tracker.await_full_state(event_ids)
Expand Down