
Re-use work of getting state for a given state_group (_get_state_groups_from_groups) #15617

1 change: 1 addition & 0 deletions changelog.d/15617.feature
@@ -0,0 +1 @@
Make `/messages` faster by efficiently grabbing state out of the database whenever we have to backfill and process new events.
134 changes: 115 additions & 19 deletions synapse/storage/databases/state/bg_updates.py
@@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple, Union

from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
@@ -89,6 +89,18 @@ def _get_state_groups_from_groups_txn(
groups: List[int],
state_filter: Optional[StateFilter] = None,
) -> Mapping[int, StateMap[str]]:
"""
Given a number of state groups, fetch the latest state for each group.

Args:
txn: The transaction object.
groups: The given state groups that you want to fetch the latest state for.
state_filter: The state filter to apply when fetching state from the database.

Returns:
Map from state_group to a StateMap at that point.
"""

state_filter = state_filter or StateFilter.all()

results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups}
@@ -98,24 +110,49 @@ def _get_state_groups_from_groups_txn(
# a temporary hack until we can add the right indices in
txn.execute("SET LOCAL enable_seqscan=off")

# The query below walks the state_group tree so that the "state"
# table includes all state_groups in the tree. It then joins
# against `state_groups_state` to fetch the latest state.
# It assumes that previous state groups are always numerically
# lesser.
# This may return multiple rows per (type, state_key), but last_value
# should be the same.
sql = """
WITH RECURSIVE sgs(state_group) AS (
VALUES(?::bigint)
WITH RECURSIVE sgs(state_group, state_group_reached) AS (
VALUES(?::bigint, NULL::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, sgs s
WHERE s.state_group = e.state_group
SELECT
prev_state_group,
CASE
/* Specify state_groups we have already done the work for */
WHEN prev_state_group IN (%s /* state_groups_we_have_already_fetched_string */) THEN prev_state_group
ELSE NULL
END AS state_group_reached
FROM
state_group_edges e, sgs s
WHERE
s.state_group = e.state_group
/* Stop when we connect up to another state_group that we already did the work for */
AND s.state_group_reached IS NULL
)
%s /* overall_select_clause */
"""

overall_select_query_args: List[Union[int, str]] = []
# Make sure we always have a row that tells us if we linked up to another
# state_group chain that we already processed (indicated by
# `state_group_reached`) regardless of whether we find any state according
# to the state_filter.
#
# We use a `UNION ALL` to make sure it is always the first row returned.
# `UNION` will merge and sort in with the rows from the next query
# otherwise.
overall_select_clause = """
(
SELECT NULL, NULL, NULL, state_group_reached
FROM sgs
ORDER BY state_group ASC
LIMIT 1
) UNION ALL (%s /* main_select_clause */)
"""

# This is an optimization to create a select clause per-condition. This
# makes the query planner a lot smarter on what rows should pull out in the
@@ -154,7 +191,7 @@ def _get_state_groups_from_groups_txn(
f"""
(
SELECT DISTINCT ON (type, state_key)
type, state_key, event_id, state_group
FROM state_groups_state
INNER JOIN sgs USING (state_group)
WHERE {where_clause}
@@ -163,7 +200,7 @@
"""
)

main_select_clause = " UNION ".join(select_clause_list)
else:
where_clause, where_args = state_filter.make_sql_filter_clause()
# Unless the filter clause is empty, we're going to append it after an
@@ -173,25 +210,83 @@

overall_select_query_args.extend(where_args)

main_select_clause = f"""
SELECT DISTINCT ON (type, state_key)
type, state_key, event_id, state_group
FROM state_groups_state
WHERE state_group IN (
SELECT state_group FROM sgs
) {where_clause}
ORDER BY type, state_key, state_group DESC
"""

# We can sort from least to greatest state_group and re-use the work from a
# lesser state_group for a greater one if we see that the edge chain links
# up.
#
# What this means in practice is that if we fetch the latest state for
# `state_group = 20`, and then we want `state_group = 30`, it will traverse
# down the edge chain to `20`, see that we linked up to `20` and bail out
# early and re-use the work we did for `20`. This can have massive savings
# in rooms like Matrix HQ where the edge chain is 88k events long and
# fetching the mostly-same chain over and over isn't very efficient.
sorted_groups = sorted(groups)
state_groups_we_have_already_fetched: Set[int] = {
# We default to `[-1]` just to fill in the query with something that
# will have no effect but not bork our query when it would be empty
# otherwise
-1
}
for group in sorted_groups:
args: List[Union[int, str]] = [group]
args.extend(state_groups_we_have_already_fetched)
args.extend(overall_select_query_args)

state_groups_we_have_already_fetched_string = ", ".join(
["?::bigint"] * len(state_groups_we_have_already_fetched)
)

txn.execute(
sql
% (
state_groups_we_have_already_fetched_string,
overall_select_clause % (main_select_clause,),
),
args,
)

# The first row is always our special `state_group_reached` row which
# tells us if we linked up to any other existing state_group that we
# already fetched and if so, which one we linked up to (see the `UNION
# ALL` above which drives this special row)
first_row = txn.fetchone()
if first_row:
_, _, _, state_group_reached = first_row

partial_state_map_for_state_group: MutableStateMap[str] = {}
for row in txn:
typ, state_key, event_id, _state_group = row
key = (intern_string(typ), intern_string(state_key))
partial_state_map_for_state_group[key] = event_id

# If we see a state_group edge link to a previous state_group that we
# already fetched from the database, link up the base state to the
# partial state we retrieved from the database to build on top of.
if state_group_reached in results:
resultant_state_map = dict(results[state_group_reached])
resultant_state_map.update(partial_state_map_for_state_group)

results[group] = resultant_state_map
else:
# It's also completely normal for us not to have a previous
# state_group to build on top of if this is the first group being
# processed or we are processing a bunch of groups from different
rooms which of course will never link together (completely
# different DAGs).
results[group] = partial_state_map_for_state_group

state_groups_we_have_already_fetched.add(group)

@MadLittleMods MadLittleMods May 18, 2023
After

[...]

In the sample request above, simulating how the app would make the queries with the changes in this PR, we don't actually see that much benefit because it turns out that there isn't much state_group sharing amongst events. That 14s SQL query is just fetching the state at each of the prev_events of one event and only one seems to share state_groups.

[...]

Initially, I thought the lack of sharing was quite strange but this is because of the state_group snapshotting feature where if an event requires more than 100 hops on state_group_edges, then Synapse will create a new state_group with a snapshot of all of that state. It seems like this isn't done very efficiently though. Relevant docs.

And it turns out the event is an org.matrix.dummy_event which Synapse automatically puts in the DAG to resolve outstanding forward extremities and these events aren't even shown to clients so we don't even need to waste time waiting for them to backfill. Tracked by #15632

Generally, I think this PR could bring great gains in conjunction with running some sort of state compressor over the database to get a lot more sharing, in addition to fixing the online state_group snapshotting logic to be smarter. I don't know how the existing state compressors work, but I imagine we could create snapshots and bucket for years -> months -> weeks -> days -> hours -> individual events and create new state_group chains which utilize these from biggest to smallest to get maximal sharing.

-- "After" section of the PR description

This PR hasn't made as big of an impact as I thought it would for that type of request. Are we still interested in a change like this? It may work well for sequential events that we backfill.

It seems like our state_group sharing is really sub-par, and the way that state_groups can only have a max of 100 hops puts an upper limit on how much gain this PR can give. I didn't anticipate that's how state_groups worked; I thought it was one state_group per state change, which it is until it starts doing snapshots.

Maybe it's more interesting to improve our state_group logic to be much smarter first and we could re-visit something like this. Or look into the state compressor stuff to optimize our backlog which would help for the Matrix Public Archive. I'm not sure if the current state compressors optimize for disk space or sharing or how inter-related those two goals are.
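To make the bucketing idea above concrete, here is a hypothetical sketch; it is not how Synapse or the existing state compressors actually work, just an illustration of how coarse-to-fine snapshot chains would maximize sharing:

```python
from datetime import datetime
from typing import List

# Hypothetical coarse-to-fine snapshot buckets; chains built from these would
# share their coarse prefixes, so two events in the same hour share everything.
BUCKET_FORMATS = ["%Y", "%Y-%m", "%Y-%W", "%Y-%m-%d", "%Y-%m-%d %H"]

def snapshot_chain(ts: datetime) -> List[str]:
    """Build snapshot identifiers from the biggest bucket to the smallest."""
    return [ts.strftime(fmt) for fmt in BUCKET_FORMATS]

a = snapshot_chain(datetime(2023, 5, 18, 14, 5))
b = snapshot_chain(datetime(2023, 5, 18, 14, 55))
assert a == b  # same hour -> identical chain, maximal sharing
c = snapshot_chain(datetime(2023, 5, 19, 9, 0))
assert a[:2] == c[:2]  # same month -> the coarse prefix is still shared
```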

else:
@MadLittleMods MadLittleMods

To not complicate the diff, I've held off on applying the same treatment to SQLite.

We can iterate on this in another PR, or just opt for people to use Postgres in order to see the performance benefit.
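For reference, a rough sketch of what that SQLite treatment might look like, given the `WITH RECURSIVE` support noted in the code comment further down; this is untested and not part of the PR:

```python
# Hypothetical sketch, not the actual refactor: SQLite >= 3.8.3 supports
# WITH RECURSIVE, so the per-group Python loop over state_group_edges could
# collapse into a single query. SQLite lacks `DISTINCT ON`, so rows are
# ordered ascending and later state_groups simply overwrite earlier ones
# when building the state map in Python.
sqlite_sql = """
    WITH RECURSIVE sgs(state_group) AS (
        VALUES(?)
        UNION ALL
        SELECT prev_state_group FROM state_group_edges e, sgs s
        WHERE s.state_group = e.state_group
    )
    SELECT type, state_key, event_id, state_group
    FROM state_groups_state
    WHERE state_group IN (SELECT state_group FROM sgs)
    ORDER BY state_group ASC
"""
```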

max_entries_returned = state_filter.max_entries_returned()

@@ -201,8 +296,9 @@ def _get_state_groups_from_groups_txn(
if where_clause:
where_clause = " AND (%s)" % (where_clause,)

# XXX: We could `WITH RECURSIVE` here since it's supported on SQLite 3.8.3
# or higher and our minimum supported version is greater than that. We just
# haven't put in the time to refactor this.
for group in groups:
next_group: Optional[int] = group
