Skip to content

Commit

Permalink
Correctly changes to required state config
Browse files Browse the repository at this point in the history
Fixes #17698
  • Loading branch information
erikjohnston committed Oct 3, 2024
1 parent d4e3ad0 commit af9c72d
Show file tree
Hide file tree
Showing 4 changed files with 475 additions and 6 deletions.
207 changes: 203 additions & 4 deletions synapse/handlers/sliding_sync/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import logging
from itertools import chain
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple
from typing import TYPE_CHECKING, AbstractSet, Dict, List, Mapping, Optional, Set, Tuple

from prometheus_client import Histogram
from typing_extensions import assert_never
Expand Down Expand Up @@ -988,6 +988,10 @@ async def get_room_sync_data(
include_others=required_state_filter.include_others,
)

# The required state map to store in the room sync config, if it has
# changed.
changed_required_state_map: Optional[Mapping[str, AbstractSet[str]]] = None

# We can return all of the state that was requested if this was the first
# time we've sent the room down this connection.
room_state: StateMap[EventBase] = {}
Expand All @@ -1001,6 +1005,29 @@ async def get_room_sync_data(
else:
assert from_bound is not None

if prev_room_sync_config is not None:
# Check if there are any changes to the required state config
# that we need to handle.
changed_required_state_map, added_state_filter = (
_required_state_changes(
user.to_string(),
prev_room_sync_config,
room_sync_config,
room_state_delta_id_map,
)
)

if added_state_filter:
# Some state entries got added, so we pull out the current
# state for them. If we don't do this we'd only send down new deltas.
state_ids = await self.get_current_state_ids_at(
room_id=room_id,
room_membership_for_user_at_to_token=room_membership_for_user_at_to_token,
state_filter=added_state_filter,
to_token=to_token,
)
room_state_delta_id_map.update(state_ids)

events = await self.store.get_events(
state_filter.filter_state(room_state_delta_id_map).values()
)
Expand Down Expand Up @@ -1083,6 +1110,10 @@ async def get_room_sync_data(
prev_room_sync_config = previous_connection_state.room_configs.get(room_id)
# Record the `room_sync_config` if we're `ignore_timeline_bound` (which means
# that the `timeline_limit` has increased)
room_sync_required_state_map_to_persist = room_sync_config.required_state_map
if changed_required_state_map:
room_sync_required_state_map_to_persist = changed_required_state_map

if ignore_timeline_bound:
# FIXME: We signal the fact that we're sending down more events to
# the client by setting `unstable_expanded_timeline` to true (see
Expand All @@ -1091,7 +1122,7 @@ async def get_room_sync_data(

new_connection_state.room_configs[room_id] = RoomSyncConfig(
timeline_limit=room_sync_config.timeline_limit,
required_state_map=room_sync_config.required_state_map,
required_state_map=room_sync_required_state_map_to_persist,
)
elif prev_room_sync_config is not None:
# If the result is `limited` then we need to record that the
Expand Down Expand Up @@ -1120,10 +1151,14 @@ async def get_room_sync_data(
):
new_connection_state.room_configs[room_id] = RoomSyncConfig(
timeline_limit=room_sync_config.timeline_limit,
required_state_map=room_sync_config.required_state_map,
required_state_map=room_sync_required_state_map_to_persist,
)

# TODO: Record changes in required_state.
elif changed_required_state_map is not None:
new_connection_state.room_configs[room_id] = RoomSyncConfig(
timeline_limit=room_sync_config.timeline_limit,
required_state_map=room_sync_required_state_map_to_persist,
)

else:
new_connection_state.room_configs[room_id] = room_sync_config
Expand Down Expand Up @@ -1242,3 +1277,167 @@ async def _get_bump_stamp(
return new_bump_event_pos.stream

return None


def _required_state_changes(
user_id: str,
previous_room_config: "RoomSyncConfig",
room_sync_config: RoomSyncConfig,
state_deltas: StateMap[str],
) -> Tuple[Optional[Mapping[str, AbstractSet[str]]], StateFilter]:
"""Calculates the changes between the required state room config from the
previous requests compared with the current request.
This does two things. First, it calculates if we need to update the room
config due to changes to required state. Secondly, it works out which state
entries we need to pull from current state and return due to the state entry
now appearing in the required state when it previously wasn't (on top of the
state deltas).
This function tries to ensure to handle the case where a state entry is
added, removed and then added again to the required state. In that case we
only want to re-send that entry down sync if it has changed.
Returns:
A 2-tuple of updated required state config and the state filter to use
to fetch extra current state that we need to return.
"""

prev_required_state_map = previous_room_config.required_state_map
request_required_state_map = room_sync_config.required_state_map

if prev_required_state_map == request_required_state_map:
# There has been no change. Return immediately.
return None, StateFilter.none()

prev_wildcard = prev_required_state_map.get(StateValues.WILDCARD, set())
request_wildcard = request_required_state_map.get(StateValues.WILDCARD, set())

# If a event type wildcard has been added or removed we don't try and do
# anything fancy, and instead always update the effective room required
# state config to match the request.
if request_wildcard - prev_wildcard:
# Some keys were added, so we need to fetch everything
return request_required_state_map, StateFilter.all()
if prev_wildcard - request_wildcard:
# Keys were only removed, so we don't have to fetch everything.
return request_required_state_map, StateFilter.none()

# Contains updates to the required state map compared with the previous room
# config. This has the same format as `RoomSyncConfig.required_state`
changes: Dict[str, AbstractSet[str]] = {}

# The set of types/state keys that we need to fetch and return to the
# client. Passed to `StateFilter.from_types(...)`
added: List[Tuple[str, Optional[str]]] = []

# First we calculate what, if anything, has been *added*.
for event_type in (
prev_required_state_map.keys() | request_required_state_map.keys()
):
old_state_keys = prev_required_state_map.get(event_type, set())
request_state_keys = request_required_state_map.get(event_type, set())

if old_state_keys == request_state_keys:
# No change to this type
continue

if not request_state_keys - old_state_keys:
# Nothing *added*, so we skip. Removals happen below.
continue

# Always update changes to include the newly added keys
changes[event_type] = request_state_keys

# Record the new state keys to fetch for this type.
if StateValues.WILDCARD in request_state_keys:
# If we have added a wildcard then we always just fetch everything.
added.append((event_type, None))
else:
for state_key in request_state_keys - old_state_keys:
if state_key == StateValues.ME:
added.append((event_type, user_id))
elif state_key == StateValues.LAZY:
# We handle lazy loading separately, so don't need to
# explicitly add anything here.
pass
else:
added.append((event_type, state_key))

added_state_filter = StateFilter.from_types(added)

# Convert the list of state deltas to map from type to state_keys that have
# changed.
changed_types_to_state_keys: Dict[str, Set[str]] = {}
for event_type, state_key in state_deltas:
changed_types_to_state_keys.setdefault(event_type, set()).add(state_key)

# Figure out what changes we need to apply to the effective required state
# config.
for event_type, changed_state_keys in changed_types_to_state_keys.items():
old_state_keys = prev_required_state_map.get(event_type, set())
request_state_keys = request_required_state_map.get(event_type, set())

if old_state_keys == request_state_keys:
# No change.
continue

if request_state_keys - old_state_keys:
# We've expanded the set of state keys, so we just clobber the
# current set with the new set.
#
# We could also ensure that we keep entries where the state hasn't
# changed, but are no longer in the requested required state, but
# that's a sufficient edge case that we can ignore (as its only a
# performance optimization).
changes[event_type] = request_state_keys
continue

old_state_key_wildcard = StateValues.WILDCARD in old_state_keys
request_state_key_wildcard = StateValues.WILDCARD in request_state_keys

if old_state_key_wildcard != request_state_key_wildcard:
# If a wildcard has been added or removed we always update the
# required state when any state with the same type has changed.
changes[event_type] = request_state_keys
continue

old_state_key_lazy = StateValues.LAZY in old_state_keys
request_state_key_lazy = StateValues.LAZY in request_state_keys

if old_state_key_lazy != request_state_key_lazy:
# If a "$LAZY" has been added or removed we always update the
# required state for simplicity.
changes[event_type] = request_state_keys
continue

# Handle "$ME" values by replacing state keys that match the user ID.
if user_id in changed_state_keys:
changed_state_keys.discard(user_id)
changed_state_keys.add(StateValues.ME)

# At this point there are no wildcards and no additions to the set of
# state keys requested, only deletions.
#
# We only remove state keys from the effective state if they've been
# removed from the request *and* the state has changed. This ensures
# that if a client removes and then readds a state key, we only send
# down the associated current state event if its changed (rather than
# sending down the same event twice).
invalidated = (old_state_keys - request_state_keys) & changed_state_keys
if invalidated:
changes[event_type] = old_state_keys - invalidated

if changes:
# Update the required state config based on the changes.
new_required_state_map = dict(prev_required_state_map)
for event_type, state_keys in changes.items():
if state_keys:
new_required_state_map[event_type] = state_keys
else:
# Remove entries with empty state keys.
new_required_state_map.pop(event_type, None)

return new_required_state_map, added_state_filter
else:
return None, added_state_filter
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/sliding_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,8 +386,8 @@ def _get_and_clear_connection_positions_txn(
required_state_map: Dict[int, Dict[str, Set[str]]] = {}
for row in rows:
state = required_state_map[row[0]] = {}
for event_type, state_keys in db_to_json(row[1]):
state[event_type] = set(state_keys)
for event_type, state_key in db_to_json(row[1]):
state.setdefault(event_type, set()).add(state_key)

# Get all the room configs, looking up the required state from the map
# above.
Expand Down
5 changes: 5 additions & 0 deletions synapse/types/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,11 @@ def __contains__(self, key: Any) -> bool:

return False

def __bool__(self) -> bool:
if self.include_others:
return True
return bool(self.types)


_ALL_STATE_FILTER = StateFilter(types=immutabledict(), include_others=True)
_ALL_NON_MEMBER_STATE_FILTER = StateFilter(
Expand Down
Loading

0 comments on commit af9c72d

Please sign in to comment.