Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add unstable /keys/claim endpoint which always returns fallback keys #15462

Merged
merged 6 commits into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/15462.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Update support for [MSC3983](https:/matrix-org/matrix-spec-proposals/pull/3983) to allow always returning fallback-keys in a `/keys/claim` request.
4 changes: 3 additions & 1 deletion synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,7 +1013,9 @@ async def on_claim_client_keys(
query.append((user_id, device_id, algorithm))

log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
results = await self._e2e_keys_handler.claim_local_one_time_keys(query)
results = await self._e2e_keys_handler.claim_local_one_time_keys(
query, always_include_fallback_keys=False
)
clokep marked this conversation as resolved.
Show resolved Hide resolved

json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for result in results:
Expand Down
13 changes: 5 additions & 8 deletions synapse/handlers/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,9 +842,7 @@ async def _check_user_exists(self, user_id: str) -> bool:

async def claim_e2e_one_time_keys(
self, query: Iterable[Tuple[str, str, str]]
) -> Tuple[
Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]]
]:
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
"""Claim one time keys from application services.

Users which are exclusively owned by an application service are sent a
Expand All @@ -856,7 +854,7 @@ async def claim_e2e_one_time_keys(

Returns:
A tuple of:
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
A map of user ID -> a map device ID -> a map of key ID -> JSON.

A copy of the input which has not been fulfilled (either because
they are not appservice users or the appservice does not support
Expand Down Expand Up @@ -897,12 +895,11 @@ async def claim_e2e_one_time_keys(
)

# Patch together the results -- they are all independent (since they
# require exclusive control over the users). They get returned as a list
# and the caller combines them.
claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = []
# require exclusive control over the users, which is the outermost key).
claimed_keys: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for success, result in results:
if success:
claimed_keys.append(result[0])
claimed_keys.update(result[0])
missing.extend(result[1])

return claimed_keys, missing
Expand Down
70 changes: 63 additions & 7 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,9 @@ async def on_federation_query_client_keys(
return ret

async def claim_local_one_time_keys(
self, local_query: List[Tuple[str, str, str]]
self,
local_query: List[Tuple[str, str, str]],
always_include_fallback_keys: bool,
) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
"""Claim one time keys for local users.

Expand All @@ -573,6 +575,7 @@ async def claim_local_one_time_keys(

Args:
local_query: An iterable of tuples of (user ID, device ID, algorithm).
always_include_fallback_keys: True to always include fallback keys.

Returns:
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
Expand All @@ -583,24 +586,73 @@ async def claim_local_one_time_keys(
# If the application services have not provided any keys via the C-S
# API, query it directly for one-time keys.
if self._query_appservices_for_otks:
# TODO Should this query for fallback keys of uploaded OTKs if
# always_include_fallback_keys is True? The MSC is ambiguous.
(
appservice_results,
not_found,
) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
else:
appservice_results = []
appservice_results = {}

# Calculate which user ID / device ID / algorithm tuples to get fallback
# keys for. This can be either only missing results *or* all results
# (which don't already have a fallback key).
if always_include_fallback_keys:
# Build the fallback query as any part of the original query where
# the appservice didn't respond with a fallback key.
fallback_query = []

# Iterate each item in the original query and search the results
# from the appservice for that user ID / device ID. If it is found,
# check if any of the keys match the requested algorithm & are a
# fallback key.
for user_id, device_id, algorithm in local_query:
# Check if the appservice responded for this query.
as_result = appservice_results.get(user_id, {}).get(device_id, {})
found_otk = False
for key_id, key_json in as_result.items():
if key_id.startswith(f"{algorithm}:"):
# A OTK or fallback key was found for this query.
found_otk = True
# A fallback key was found for this query, no need to
# query further.
if key_json.get("fallback", False):
break

else:
# No fallback key was found from appservices, query for it.
# Only mark the fallback key as used if no OTK was found
# (from either the database or appservices).
mark_as_used = not found_otk and not any(
key_id.startswith(f"{algorithm}:")
for key_id in otk_results.get(user_id, {})
.get(device_id, {})
.keys()
)
fallback_query.append((user_id, device_id, algorithm, mark_as_used))

else:
# All fallback keys get marked as used.
fallback_query = [
(user_id, device_id, algorithm, True)
for user_id, device_id, algorithm in not_found
]

# For each user that does not have a one-time keys available, see if
# there is a fallback key.
fallback_results = await self.store.claim_e2e_fallback_keys(not_found)
fallback_results = await self.store.claim_e2e_fallback_keys(fallback_query)

# Return the results in order, each item from the input query should
# only appear once in the combined list.
return (otk_results, *appservice_results, fallback_results)
return (otk_results, appservice_results, fallback_results)

@trace
async def claim_one_time_keys(
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
self,
query: Dict[str, Dict[str, Dict[str, str]]],
timeout: Optional[int],
always_include_fallback_keys: bool,
) -> JsonDict:
local_query: List[Tuple[str, str, str]] = []
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
Expand All @@ -617,15 +669,19 @@ async def claim_one_time_keys(
set_tag("local_key_query", str(local_query))
set_tag("remote_key_query", str(remote_queries))

results = await self.claim_local_one_time_keys(local_query)
results = await self.claim_local_one_time_keys(
local_query, always_include_fallback_keys
)

# A map of user ID -> device ID -> key ID -> key.
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for result in results:
for user_id, device_keys in result.items():
for device_id, keys in device_keys.items():
for key_id, key in keys.items():
json_result.setdefault(user_id, {})[device_id] = {key_id: key}
json_result.setdefault(user_id, {}).setdefault(
device_id, {}
).update({key_id: key})

# Remote failures.
failures: Dict[str, JsonDict] = {}
Expand Down
18 changes: 17 additions & 1 deletion synapse/rest/client/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.

import logging
import re
from typing import TYPE_CHECKING, Any, Optional, Tuple

from synapse.api.errors import InvalidAPICallError, SynapseError
Expand Down Expand Up @@ -283,15 +284,28 @@ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self._always_include_fallback_keys = False
clokep marked this conversation as resolved.
Show resolved Hide resolved

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout)
result = await self.e2e_keys_handler.claim_one_time_keys(
body,
timeout,
always_include_fallback_keys=self._always_include_fallback_keys,
)
return 200, result


class UnstableOneTimeKeyServlet(OneTimeKeyServlet):
PATTERNS = [re.compile(r"^/_matrix/client/unstable/org.matrix.msc3983/keys/claim$")]

def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._always_include_fallback_keys = True


class SigningKeyUploadServlet(RestServlet):
"""
POST /keys/device_signing/upload HTTP/1.1
Expand Down Expand Up @@ -394,6 +408,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
KeyQueryServlet(hs).register(http_server)
KeyChangesServlet(hs).register(http_server)
OneTimeKeyServlet(hs).register(http_server)
if hs.config.experimental.msc3983_appservice_otk_claims:
UnstableOneTimeKeyServlet(hs).register(http_server)
if hs.config.worker.worker_app is None:
SigningKeyUploadServlet(hs).register(http_server)
SignaturesUploadServlet(hs).register(http_server)
9 changes: 5 additions & 4 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,18 +1149,19 @@ def _claim_e2e_one_time_key_returning(
return results, missing

async def claim_e2e_fallback_keys(
self, query_list: Iterable[Tuple[str, str, str]]
self, query_list: Iterable[Tuple[str, str, str, bool]]
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
"""Take a list of fallback keys out of the database.

Args:
query_list: An iterable of tuples of (user ID, device ID, algorithm).
query_list: An iterable of tuples of
(user ID, device ID, algorithm, whether the key should be marked as used).

Returns:
A map of user ID -> a map device ID -> a map of key ID -> JSON.
"""
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, device_id, algorithm in query_list:
for user_id, device_id, algorithm, mark_as_used in query_list:
row = await self.db_pool.simple_select_one(
table="e2e_fallback_keys_json",
keyvalues={
Expand All @@ -1180,7 +1181,7 @@ async def claim_e2e_fallback_keys(
used = row["used"]

# Mark fallback key as used if not already.
if not used:
if not used and mark_as_used:
await self.db_pool.simple_update_one(
table="e2e_fallback_keys_json",
keyvalues={
Expand Down
Loading