Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Refactor MSC2716 /batch_send endpoint into separate handler functions #10974

Merged
1 change: 1 addition & 0 deletions changelog.d/10974.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor [MSC2716](https:/matrix-org/matrix-doc/pull/2716) `/batch_send` mega function into smaller handler functions.
373 changes: 373 additions & 0 deletions synapse/handlers/room_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,373 @@
import logging
from typing import TYPE_CHECKING, List, Tuple

from synapse.http.servlet import (
assert_params_in_dict,
)
from synapse.api.constants import EventContentFields, EventTypes
from synapse.appservice import ApplicationService
from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util.stringutils import random_string

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


class RoomBatchHandler:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Functionality is the same. Just copying logic from synapse/rest/client/room_batch.py

(Complement tests pass)

"""Contains some read only APIs to get state about a room"""
MadLittleMods marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.state_store = hs.get_storage().state
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()

async def inherit_depth_from_prev_ids(self, prev_event_ids: List[str]) -> int:
MadLittleMods marked this conversation as resolved.
Show resolved Hide resolved
(
most_recent_prev_event_id,
most_recent_prev_event_depth,
) = await self.store.get_max_depth_of(prev_event_ids)

# We want to insert the historical event after the `prev_event` but before the successor event
#
# We inherit depth from the successor event instead of the `prev_event`
# because events returned from `/messages` are first sorted by `topological_ordering`
# which is just the `depth` and then tie-break with `stream_ordering`.
#
# We mark these inserted historical events as "backfilled" which gives them a
# negative `stream_ordering`. If we use the same depth as the `prev_event`,
# then our historical event will tie-break and be sorted before the `prev_event`
# when it should come after.
#
# We want to use the successor event depth so they appear after `prev_event` because
# it has a larger `depth` but before the successor event because the `stream_ordering`
# is negative before the successor event.
successor_event_ids = await self.store.get_successor_events(
[most_recent_prev_event_id]
)

# If we can't find any successor events, then it's a forward extremity of
# historical messages and we can just inherit from the previous historical
# event which we can already assume has the correct depth where we want
# to insert into.
if not successor_event_ids:
depth = most_recent_prev_event_depth
else:
(
_,
oldest_successor_depth,
) = await self.store.get_min_depth_of(successor_event_ids)

depth = oldest_successor_depth

return depth

def create_insertion_event_dict(
self, sender: str, room_id: str, origin_server_ts: int
) -> JsonDict:
"""Creates an event dict for an "insertion" event with the proper fields
and a random batch ID.

Args:
sender: The event author MXID
room_id: The room ID that the event belongs to
origin_server_ts: Timestamp when the event was sent

Returns:
The new event dictionary to insert.
"""

next_batch_id = random_string(8)
insertion_event = {
"type": EventTypes.MSC2716_INSERTION,
"sender": sender,
"room_id": room_id,
"content": {
EventContentFields.MSC2716_NEXT_BATCH_ID: next_batch_id,
EventContentFields.MSC2716_HISTORICAL: True,
},
"origin_server_ts": origin_server_ts,
}

return insertion_event

async def create_requester_for_user_id_from_app_service(
self, user_id: str, app_service: ApplicationService
) -> Requester:
"""Creates a new requester for the given user_id
and validates that the app service is allowed to control
the given user.

Args:
user_id: The author MXID that the app service is controlling
app_service: The app service that controls the user

Returns:
Requester object
"""

await self.auth.validate_appservice_can_control_user_id(app_service, user_id)

return create_requester(user_id, app_service=app_service)

async def getMostRecentAuthEventIdsFromEventIdList(
MadLittleMods marked this conversation as resolved.
Show resolved Hide resolved
self, event_ids: List[str]
) -> List[str]:
"""Find the most recent auth event ids (derived from state events) that
allowed that message to be sent. We will use that as a base
to auth our historical messages against.
"""

(
most_recent_prev_event_id,
_,
) = await self.store.get_max_depth_of(event_ids)
# mapping from (type, state_key) -> state_event_id
prev_state_map = await self.state_store.get_state_ids_for_event(
most_recent_prev_event_id
)
# List of state event ID's
prev_state_ids = list(prev_state_map.values())
auth_event_ids = prev_state_ids

return auth_event_ids

async def persistStateEventsAtStart(
self,
state_events_at_start: List[JsonDict],
room_id: str,
initial_auth_event_ids: List[str],
requester: Requester,
) -> List[str]:
"""Takes all `state_events_at_start` event dictionaries and creates/persists
them as floating state events which don't resolve into the current room state.
They are floating because they reference a fake prev_event which doesn't connect
to the normal DAG at all.

Returns:
List of state event ID's we just persisted
"""
assert requester.app_service

state_event_ids_at_start = []
auth_event_ids = initial_auth_event_ids.copy()
for state_event in state_events_at_start:
assert_params_in_dict(
state_event, ["type", "origin_server_ts", "content", "sender"]
)

logger.debug(
"RoomBatchSendEventRestServlet inserting state_event=%s, auth_event_ids=%s",
state_event,
auth_event_ids,
)

event_dict = {
"type": state_event["type"],
"origin_server_ts": state_event["origin_server_ts"],
"content": state_event["content"],
"room_id": room_id,
"sender": state_event["sender"],
"state_key": state_event["state_key"],
}

# Mark all events as historical
event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True

# Make the state events float off on their own so we don't have a
# bunch of `@mxid joined the room` noise between each batch
fake_prev_event_id = "$" + random_string(43)

# TODO: This is pretty much the same as some other code to handle inserting state in this file
if event_dict["type"] == EventTypes.Member:
membership = event_dict["content"].get("membership", None)
event_id, _ = await self.room_member_handler.update_membership(
await self.create_requester_for_user_id_from_app_service(
state_event["sender"], requester.app_service
),
target=UserID.from_string(event_dict["state_key"]),
room_id=room_id,
action=membership,
content=event_dict["content"],
outlier=True,
prev_event_ids=[fake_prev_event_id],
# Make sure to use a copy of this list because we modify it
# later in the loop here. Otherwise it will be the same
# reference and also update in the event when we append later.
auth_event_ids=auth_event_ids.copy(),
)
else:
# TODO: Add some complement tests that adds state that is not member joins
# and will use this code path. Maybe we only want to support join state events
# and can get rid of this `else`?
(
event,
_,
) = await self.event_creation_handler.create_and_send_nonmember_event(
await self.create_requester_for_user_id_from_app_service(
state_event["sender"], requester.app_service
),
event_dict,
outlier=True,
prev_event_ids=[fake_prev_event_id],
# Make sure to use a copy of this list because we modify it
# later in the loop here. Otherwise it will be the same
# reference and also update in the event when we append later.
auth_event_ids=auth_event_ids.copy(),
)
event_id = event.event_id

state_event_ids_at_start.append(event_id)
auth_event_ids.append(event_id)

return state_event_ids_at_start

async def persistHistoricalEvents(
self,
events_to_create: List[JsonDict],
room_id: str,
initial_prev_event_ids: List[str],
inherited_depth: int,
auth_event_ids: List[str],
requester: Requester,
) -> List[str]:
"""Create and persists all events provided sequentially. Handles the
complexity of creating events in chronological order so they can
reference each other by prev_event but still persists in
reverse-chronoloical order so they have the correct
(topological_ordering, stream_ordering) and sort correctly from
/messages.

Returns:
List of persisted event IDs
"""
assert requester.app_service

prev_event_ids = initial_prev_event_ids.copy()

event_ids = []
events_to_persist = []
for ev in events_to_create:
assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"])

event_dict = {
"type": ev["type"],
"origin_server_ts": ev["origin_server_ts"],
"content": ev["content"],
"room_id": room_id,
"sender": ev["sender"], # requester.user.to_string(),
"prev_events": prev_event_ids.copy(),
}

# Mark all events as historical
event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True

event, context = await self.event_creation_handler.create_event(
await self.create_requester_for_user_id_from_app_service(
ev["sender"], requester.app_service
),
event_dict,
prev_event_ids=event_dict.get("prev_events"),
auth_event_ids=auth_event_ids,
historical=True,
depth=inherited_depth,
)
logger.debug(
"RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s",
event,
prev_event_ids,
auth_event_ids,
)

assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
event.sender,
)

events_to_persist.append((event, context))
event_id = event.event_id

event_ids.append(event_id)
prev_event_ids = [event_id]

# Persist events in reverse-chronological order so they have the
# correct stream_ordering as they are backfilled (which decrements).
# Events are sorted by (topological_ordering, stream_ordering)
# where topological_ordering is just depth.
for (event, context) in reversed(events_to_persist):
await self.event_creation_handler.handle_new_client_event(
await self.create_requester_for_user_id_from_app_service(
event["sender"], requester.app_service
),
event=event,
context=context,
)

return event_ids

async def handleBatchOfEvents(
self,
events_to_create: List[JsonDict],
room_id: str,
batch_id_to_connect_to: str,
initial_prev_event_ids: List[str],
inherited_depth: int,
auth_event_ids: List[str],
requester: Requester,
) -> Tuple[List[str], str]:
"""
Handles creating and persisting all of the historical events as well
as insertion and batch meta events to make the batch navigable in the DAG.

Returns:
Tuple containing a list of created events and the next_batch_id
"""

# Connect this current batch to the insertion event from the previous batch
last_event_in_batch = events_to_create[-1]
batch_event = {
"type": EventTypes.MSC2716_BATCH,
"sender": requester.user.to_string(),
"room_id": room_id,
"content": {
EventContentFields.MSC2716_BATCH_ID: batch_id_to_connect_to,
EventContentFields.MSC2716_HISTORICAL: True,
},
# Since the batch event is put at the end of the batch,
# where the newest-in-time event is, copy the origin_server_ts from
# the last event we're inserting
"origin_server_ts": last_event_in_batch["origin_server_ts"],
}
# Add the batch event to the end of the batch (newest-in-time)
events_to_create.append(batch_event)

# Add an "insertion" event to the start of each batch (next to the oldest-in-time
# event in the batch) so the next batch can be connected to this one.
insertion_event = self.create_insertion_event_dict(
sender=requester.user.to_string(),
room_id=room_id,
# Since the insertion event is put at the start of the batch,
# where the oldest-in-time event is, copy the origin_server_ts from
# the first event we're inserting
origin_server_ts=events_to_create[0]["origin_server_ts"],
)
next_batch_id = insertion_event["content"][
EventContentFields.MSC2716_NEXT_BATCH_ID
]
# Prepend the insertion event to the start of the batch (oldest-in-time)
events_to_create = [insertion_event] + events_to_create

# Create and persist all of the historical events
event_ids = await self.persistHistoricalEvents(
events_to_create=events_to_create,
room_id=room_id,
initial_prev_event_ids=initial_prev_event_ids,
inherited_depth=inherited_depth,
auth_event_ids=auth_event_ids,
requester=requester,
)

return event_ids, next_batch_id
Loading