Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Implement MSC3706: partial state in /send_join response #11967

Merged
merged 3 commits into from
Feb 12, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11967.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Experimental implementation of [MSC3706](https:/matrix-org/matrix-doc/pull/3706): extensions to `/send_join` to support reduced response size.
3 changes: 3 additions & 0 deletions synapse/config/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,6 @@ def read_config(self, config: JsonDict, **kwargs):
self.msc2409_to_device_messages_enabled: bool = experimental.get(
"msc2409_to_device_messages_enabled", False
)

# MSC3706 (server-side support for partial state in /send_join responses)
self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False)
69 changes: 66 additions & 3 deletions synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Any,
Awaitable,
Callable,
Collection,
Dict,
Iterable,
List,
Expand Down Expand Up @@ -64,7 +65,7 @@
ReplicationGetQueryRestServlet,
)
from synapse.storage.databases.main.lock import Lock
from synapse.types import JsonDict, get_domain_from_id
from synapse.types import JsonDict, StateMap, get_domain_from_id
from synapse.util import json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
from synapse.util.caches.response_cache import ResponseCache
Expand Down Expand Up @@ -645,19 +646,40 @@ async def on_invite_request(
return {"event": ret_pdu.get_pdu_json(time_now)}

async def on_send_join_request(
self, origin: str, content: JsonDict, room_id: str
self,
origin: str,
content: JsonDict,
room_id: str,
caller_supports_partial_state: bool = False,
) -> Dict[str, Any]:
event, context = await self._on_send_membership_event(
origin, content, Membership.JOIN, room_id
)

prev_state_ids = await context.get_prev_state_ids()
state_event_ids = prev_state_ids.values()

state_event_ids: Collection[str]
servers_in_room: Optional[Collection[str]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are basically so that the joining homeserver has some other homeservers to contact to fetch events if we were to just vanish, right?
(I guess also because the joining homeserver might need to send events and for that, it needs to know which servers are in the room — usually this is understood from the membership events but it won't have them all yet under this new scheme.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess also because the joining homeserver might need to send events and for that, it needs to know which servers are in the room

yeah, this is the main driver right now.

if caller_supports_partial_state:
state_event_ids = _get_event_ids_for_partial_state_join(
event, prev_state_ids
)
servers_in_room = await self.state.get_hosts_in_room_at_events(
room_id, event_ids=event.prev_event_ids()
)
else:
state_event_ids = prev_state_ids.values()
servers_in_room = None

auth_chain_event_ids = await self.store.get_auth_chain_ids(
room_id, state_event_ids
)

# if the caller has opted in, we can omit any auth_chain events which are
# already in state_event_ids
if caller_supports_partial_state:
auth_chain_event_ids.difference_update(state_event_ids)

auth_chain_events = await self.store.get_events_as_list(auth_chain_event_ids)
state_events = await self.store.get_events_as_list(state_event_ids)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_events_as_list is equivalent to get_events, except that it doesn't build a dict of the results. Since we only ever used the .values of the result, the dict-building was entirely redundant. (The only thing it could plausibly have done was dedup the results, but since state_event_ids is a set anyway, there couldn't have been any dups in the first place)


Expand All @@ -671,7 +693,12 @@ async def on_send_join_request(
"event": event_json,
"state": [p.get_pdu_json(time_now) for p in state_events],
"auth_chain": [p.get_pdu_json(time_now) for p in auth_chain_events],
"org.matrix.msc3706.partial_state": caller_supports_partial_state,
}

if servers_in_room is not None:
resp["org.matrix.msc3706.servers_in_room"] = list(servers_in_room)

return resp

async def on_make_leave_request(
Expand Down Expand Up @@ -1347,3 +1374,39 @@ async def on_query(self, query_type: str, args: dict) -> JsonDict:
# error.
logger.warning("No handler registered for query type %s", query_type)
raise NotFoundError("No handler for Query type '%s'" % (query_type,))


def _get_event_ids_for_partial_state_join(
join_event: EventBase,
prev_state_ids: StateMap[str],
) -> Collection[str]:
"""Calculate state to be retuned in a partial_state send_join

Args:
join_event: the join event being send_joined
prev_state_ids: the event ids of the state before the join

Returns:
the event ids to be returned
"""

# return all non-member events
state_event_ids = {
event_id
for (event_type, state_key), event_id in prev_state_ids.items()
if event_type != EventTypes.Member
}

# we also need the current state of the current user (it's going to
# be an auth event for the new join, so we may as well return it)
current_membership_event_id = prev_state_ids.get(
(EventTypes.Member, join_event.state_key)
)
if current_membership_event_id is not None:
state_event_ids.add(current_membership_event_id)

# TODO: return a few more members:
# - those with invites
# - those that are kicked? / banned
Comment on lines +1408 to +1410
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interested in more rationale/elaboration here, if you have any more thoughts.
I guess I can see invites and bans because those are useful for validating other users' join events later.
However, due to e.g. annoying spammers, kicks and bans could possibly be a large amount of the state events for some rooms.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, MSC2775 lists a whole bunch of members we might want to return. I'm unconvinced by some of them, so this is here just as a thing to return to.


return state_event_ids
20 changes: 19 additions & 1 deletion synapse/federation/transport/server/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,16 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet):

PREFIX = FEDERATION_V2_PREFIX

def __init__(
self,
hs: "HomeServer",
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
server_name: str,
):
super().__init__(hs, authenticator, ratelimiter, server_name)
self._msc3706_enabled = hs.config.experimental.msc3706_enabled

async def on_PUT(
self,
origin: str,
Expand All @@ -422,7 +432,15 @@ async def on_PUT(
) -> Tuple[int, JsonDict]:
# TODO(paul): assert that event_id parsed from path actually
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't even know who Paul is! :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's @leonerd :). He did some contracting for us back in the pre-NVL days. Which tells us how long this particular TODO has gone undone :'(

# match those given in content
result = await self.handler.on_send_join_request(origin, content, room_id)

partial_state = False
if self._msc3706_enabled:
partial_state = parse_boolean_from_args(
query, "org.matrix.msc3706.partial_state", default=False
)
result = await self.handler.on_send_join_request(
origin, content, room_id, caller_supports_partial_state=partial_state
)
return 200, result


Expand Down
51 changes: 51 additions & 0 deletions tests/federation/test_federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,57 @@ def test_send_join(self):
)
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")

@override_config({"experimental_features": {"msc3706_enabled": True}})
def test_send_join_partial_state(self):
"""When MSC3706 support is enabled, /send_join should return partial state"""
joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME
join_result = self._make_join(joining_user)

join_event_dict = join_result["event"]
add_hashes_and_signatures(
KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
join_event_dict,
signature_name=self.OTHER_SERVER_NAME,
signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
)
channel = self.make_signed_federation_request(
"PUT",
f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true",
content=join_event_dict,
)
self.assertEquals(channel.code, 200, channel.json_body)

# expect a reduced room state
returned_state = [
(ev["type"], ev["state_key"]) for ev in channel.json_body["state"]
]
self.assertCountEqual(
returned_state,
[
("m.room.create", ""),
("m.room.power_levels", ""),
("m.room.join_rules", ""),
("m.room.history_visibility", ""),
],
)

# the auth chain should not include anything already in "state"
returned_auth_chain_events = [
(ev["type"], ev["state_key"]) for ev in channel.json_body["auth_chain"]
]
self.assertCountEqual(
returned_auth_chain_events,
[
("m.room.member", "@kermit:test"),
],
)

# the room should show that the new user is a member
r = self.get_success(
self.hs.get_state_handler().get_current_state(self._room_id)
)
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")


def _create_acl_event(content):
return make_event_from_dict(
Expand Down