Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Implement MSC3706: partial state in /send_join response
Browse files Browse the repository at this point in the history
  • Loading branch information
richvdh committed Feb 11, 2022
1 parent 3ecb588 commit c2505f6
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 4 deletions.
1 change: 1 addition & 0 deletions changelog.d/11967.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Experimental implementation of [MSC3706](https:/matrix-org/matrix-doc/pull/3706): extensions to `/send_join` to support reduced response size.
3 changes: 3 additions & 0 deletions synapse/config/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,6 @@ def read_config(self, config: JsonDict, **kwargs):
self.msc2409_to_device_messages_enabled: bool = experimental.get(
"msc2409_to_device_messages_enabled", False
)

# MSC3706 (server-side support for partial state in /send_join responses)
self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False)
69 changes: 66 additions & 3 deletions synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Any,
Awaitable,
Callable,
Collection,
Dict,
Iterable,
List,
Expand Down Expand Up @@ -64,7 +65,7 @@
ReplicationGetQueryRestServlet,
)
from synapse.storage.databases.main.lock import Lock
from synapse.types import JsonDict, get_domain_from_id
from synapse.types import JsonDict, StateMap, get_domain_from_id
from synapse.util import json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
from synapse.util.caches.response_cache import ResponseCache
Expand Down Expand Up @@ -645,19 +646,40 @@ async def on_invite_request(
return {"event": ret_pdu.get_pdu_json(time_now)}

async def on_send_join_request(
self, origin: str, content: JsonDict, room_id: str
self,
origin: str,
content: JsonDict,
room_id: str,
caller_supports_partial_state: bool = False,
) -> Dict[str, Any]:
event, context = await self._on_send_membership_event(
origin, content, Membership.JOIN, room_id
)

prev_state_ids = await context.get_prev_state_ids()
state_event_ids = prev_state_ids.values()

state_event_ids: Collection[str]
servers_in_room: Optional[Collection[str]]
if caller_supports_partial_state:
state_event_ids = _get_event_ids_for_partial_state_join(
event, prev_state_ids
)
servers_in_room = await self.state.get_hosts_in_room_at_events(
room_id, event_ids=event.prev_event_ids()
)
else:
state_event_ids = prev_state_ids.values()
servers_in_room = None

auth_chain_event_ids = await self.store.get_auth_chain_ids(
room_id, state_event_ids
)

# if the caller has opted in, we can omit any auth_chain events which are
# already in state_event_ids
if caller_supports_partial_state:
auth_chain_event_ids.difference_update(state_event_ids)

auth_chain_events = await self.store.get_events_as_list(auth_chain_event_ids)
state_events = await self.store.get_events_as_list(state_event_ids)

Expand All @@ -671,7 +693,12 @@ async def on_send_join_request(
"event": event_json,
"state": [p.get_pdu_json(time_now) for p in state_events],
"auth_chain": [p.get_pdu_json(time_now) for p in auth_chain_events],
"org.matrix.msc3706.partial_state": caller_supports_partial_state,
}

if servers_in_room is not None:
resp["org.matrix.msc3706.servers_in_room"] = list(servers_in_room)

return resp

async def on_make_leave_request(
Expand Down Expand Up @@ -1347,3 +1374,39 @@ async def on_query(self, query_type: str, args: dict) -> JsonDict:
# error.
logger.warning("No handler registered for query type %s", query_type)
raise NotFoundError("No handler for Query type '%s'" % (query_type,))


def _get_event_ids_for_partial_state_join(
join_event: EventBase,
prev_state_ids: StateMap[str],
) -> Collection[str]:
"""Calculate state to be retuned in a partial_state send_join
Args:
join_event: the join event being send_joined
prev_state_ids: the event ids of the state before the join
Returns:
the event ids to be returned
"""

# return all non-member events
state_event_ids = {
event_id
for (event_type, state_key), event_id in prev_state_ids.items()
if event_type != EventTypes.Member
}

# we also need the current state of the current user (it's going to
# be an auth event for the new join, so we may as well return it)
current_membership_event_id = prev_state_ids.get(
(EventTypes.Member, join_event.state_key)
)
if current_membership_event_id is not None:
state_event_ids.add(current_membership_event_id)

# TODO: return a few more members:
# - those with invites
# - those that are kicked? / banned

return state_event_ids
20 changes: 19 additions & 1 deletion synapse/federation/transport/server/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,16 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet):

PREFIX = FEDERATION_V2_PREFIX

def __init__(
self,
hs: "HomeServer",
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
server_name: str,
):
super().__init__(hs, authenticator, ratelimiter, server_name)
self._msc3706_enabled = hs.config.experimental.msc3706_enabled

async def on_PUT(
self,
origin: str,
Expand All @@ -422,7 +432,15 @@ async def on_PUT(
) -> Tuple[int, JsonDict]:
# TODO(paul): assert that event_id parsed from path actually
# match those given in content
result = await self.handler.on_send_join_request(origin, content, room_id)

partial_state = False
if self._msc3706_enabled:
partial_state = parse_boolean_from_args(
query, "org.matrix.msc3706.partial_state", default=False
)
result = await self.handler.on_send_join_request(
origin, content, room_id, caller_supports_partial_state=partial_state
)
return 200, result


Expand Down
51 changes: 51 additions & 0 deletions tests/federation/test_federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,57 @@ def test_send_join(self):
)
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")

@override_config({"experimental_features": {"msc3706_enabled": True}})
def test_send_join_partial_state(self):
"""When MSC3706 support is enabled, /send_join should return partial state"""
joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME
join_result = self._make_join(joining_user)

join_event_dict = join_result["event"]
add_hashes_and_signatures(
KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
join_event_dict,
signature_name=self.OTHER_SERVER_NAME,
signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
)
channel = self.make_signed_federation_request(
"PUT",
f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true",
content=join_event_dict,
)
self.assertEquals(channel.code, 200, channel.json_body)

# expect a reduced room state
returned_state = [
(ev["type"], ev["state_key"]) for ev in channel.json_body["state"]
]
self.assertCountEqual(
returned_state,
[
("m.room.create", ""),
("m.room.power_levels", ""),
("m.room.join_rules", ""),
("m.room.history_visibility", ""),
],
)

# the auth chain should not include anything already in "state"
returned_auth_chain_events = [
(ev["type"], ev["state_key"]) for ev in channel.json_body["auth_chain"]
]
self.assertCountEqual(
returned_auth_chain_events,
[
("m.room.member", "@kermit:test"),
],
)

# the room should show that the new user is a member
r = self.get_success(
self.hs.get_state_handler().get_current_state(self._room_id)
)
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")


def _create_acl_event(content):
return make_event_from_dict(
Expand Down

0 comments on commit c2505f6

Please sign in to comment.