Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to misc. files. #9676

Merged
merged 3 commits into from
Mar 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/9676.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to third party event rules and visibility modules.
5 changes: 4 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ files =
synapse/crypto,
synapse/event_auth.py,
synapse/events/builder.py,
synapse/events/validator.py,
synapse/events/spamcheck.py,
synapse/events/third_party_rules.py,
synapse/events/validator.py,
synapse/federation,
synapse/groups,
synapse/handlers,
Expand All @@ -38,6 +39,7 @@ files =
synapse/push,
synapse/replication,
synapse/rest,
synapse/secrets.py,
synapse/server.py,
synapse/server_notices,
synapse/spam_checker_api,
Expand Down Expand Up @@ -71,6 +73,7 @@ files =
synapse/util/metrics.py,
synapse/util/macaroons.py,
synapse/util/stringutils.py,
synapse/visibility.py,
tests/replication,
tests/test_utils,
tests/handlers/test_password_providers.py,
Expand Down
15 changes: 8 additions & 7 deletions synapse/events/third_party_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Union
from typing import TYPE_CHECKING, Union

from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.types import Requester, StateMap

if TYPE_CHECKING:
from synapse.server import HomeServer


class ThirdPartyEventRules:
"""Allows server admins to provide a Python module implementing an extra
Expand All @@ -28,7 +31,7 @@ class ThirdPartyEventRules:
behaviours.
"""

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.third_party_rules = None

self.store = hs.get_datastore()
Expand Down Expand Up @@ -95,10 +98,9 @@ async def on_create_room(
if self.third_party_rules is None:
return True

ret = await self.third_party_rules.on_create_room(
return await self.third_party_rules.on_create_room(
requester, config, is_requester_admin
)
return ret

async def check_threepid_can_be_invited(
self, medium: str, address: str, room_id: str
Expand All @@ -119,10 +121,9 @@ async def check_threepid_can_be_invited(

state_events = await self._get_state_map_for_room(room_id)

ret = await self.third_party_rules.check_threepid_can_be_invited(
return await self.third_party_rules.check_threepid_can_be_invited(
medium, address, state_events
)
return ret

async def check_visibility_can_be_modified(
self, room_id: str, new_visibility: str
Expand All @@ -143,7 +144,7 @@ async def check_visibility_can_be_modified(
check_func = getattr(
self.third_party_rules, "check_visibility_can_be_modified", None
)
if not check_func or not isinstance(check_func, Callable):
if not check_func or not callable(check_func):
return True

state_events = await self._get_state_map_for_room(room_id)
Expand Down
8 changes: 4 additions & 4 deletions synapse/secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
import secrets

class Secrets:
def token_bytes(self, nbytes=32):
def token_bytes(self, nbytes: int = 32) -> bytes:
return secrets.token_bytes(nbytes)

def token_hex(self, nbytes=32):
def token_hex(self, nbytes: int = 32) -> str:
return secrets.token_hex(nbytes)


Expand All @@ -38,8 +38,8 @@ def token_hex(self, nbytes=32):
import os

class Secrets:
def token_bytes(self, nbytes=32):
def token_bytes(self, nbytes: int = 32) -> bytes:
return os.urandom(nbytes)

def token_hex(self, nbytes=32):
def token_hex(self, nbytes: int = 32) -> str:
return binascii.hexlify(self.token_bytes(nbytes)).decode("ascii")
4 changes: 2 additions & 2 deletions synapse/storage/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def _get_state_groups_from_groups(
return self.stores.state._get_state_groups_from_groups(groups, state_filter)

async def get_state_for_events(
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
) -> Dict[str, StateMap[EventBase]]:
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event.
Expand Down Expand Up @@ -485,7 +485,7 @@ async def get_state_for_events(
return {event: event_to_state[event] for event in event_ids}

async def get_state_ids_for_events(
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
) -> Dict[str, StateMap[str]]:
"""
Get the state dicts corresponding to a list of events, containing the event_ids
Expand Down
78 changes: 38 additions & 40 deletions synapse/visibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import operator
from typing import Dict, FrozenSet, List, Optional

from synapse.api.constants import (
AccountDataTypes,
EventTypes,
HistoryVisibility,
Membership,
)
from synapse.events import EventBase
from synapse.events.utils import prune_event
from synapse.storage import Storage
from synapse.storage.state import StateFilter
from synapse.types import get_domain_from_id
from synapse.types import StateMap, get_domain_from_id

logger = logging.getLogger(__name__)

Expand All @@ -48,32 +49,32 @@

async def filter_events_for_client(
storage: Storage,
user_id,
events,
is_peeking=False,
always_include_ids=frozenset(),
filter_send_to_client=True,
):
user_id: str,
events: List[EventBase],
is_peeking: bool = False,
always_include_ids: FrozenSet[str] = frozenset(),
filter_send_to_client: bool = True,
) -> List[EventBase]:
"""
Check which events a user is allowed to see. If the user can see the event but its
sender asked for their data to be erased, prune the content of the event.

Args:
storage
user_id(str): user id to be checked
events(list[synapse.events.EventBase]): sequence of events to be checked
is_peeking(bool): should be True if:
user_id: user id to be checked
events: sequence of events to be checked
is_peeking: should be True if:
* the user is not currently a member of the room, and:
* the user has not been a member of the room since the given
events
always_include_ids (set(event_id)): set of event ids to specifically
always_include_ids: set of event ids to specifically
include (unless sender is ignored)
filter_send_to_client (bool): Whether we're checking an event that's going to be
filter_send_to_client: Whether we're checking an event that's going to be
sent to a client. This might not always be the case since this function can
also be called to check whether a user can see the state at a given point.

Returns:
list[synapse.events.EventBase]
The filtered events.
"""
# Filter out events that have been soft failed so that we don't relay them
# to clients.
Expand All @@ -90,7 +91,7 @@ async def filter_events_for_client(
AccountDataTypes.IGNORED_USER_LIST, user_id
)

ignore_list = frozenset()
ignore_list = frozenset() # type: FrozenSet[str]
if ignore_dict_content:
ignored_users_dict = ignore_dict_content.get("ignored_users", {})
if isinstance(ignored_users_dict, dict):
Expand All @@ -107,19 +108,18 @@ async def filter_events_for_client(
room_id
] = await storage.main.get_retention_policy_for_room(room_id)

def allowed(event):
def allowed(event: EventBase) -> Optional[EventBase]:
"""
Args:
event (synapse.events.EventBase): event to check
event: event to check

Returns:
None|EventBase:
None if the user cannot see this event at all
None if the user cannot see this event at all

a redacted copy of the event if they can only see a redacted
version
a redacted copy of the event if they can only see a redacted
version

the original event if they can see it as normal.
the original event if they can see it as normal.
"""
# Only run some checks if these events aren't about to be sent to clients. This is
# because, if this is not the case, we're probably only checking if the users can
Expand Down Expand Up @@ -252,48 +252,46 @@ def allowed(event):

return event

# check each event: gives an iterable[None|EventBase]
# Check each event: gives an iterable of None or (a potentially modified)
# EventBase.
filtered_events = map(allowed, events)

# remove the None entries
filtered_events = filter(operator.truth, filtered_events)

# we turn it into a list before returning it.
return list(filtered_events)
# Turn it into a list and remove None entries before returning.
return [ev for ev in filtered_events if ev]
Comment on lines +259 to +260
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just found this easier to read than using filter. 🤷



async def filter_events_for_server(
storage: Storage,
server_name,
events,
redact=True,
check_history_visibility_only=False,
):
server_name: str,
events: List[EventBase],
redact: bool = True,
check_history_visibility_only: bool = False,
) -> List[EventBase]:
"""Filter a list of events based on whether given server is allowed to
see them.

Args:
storage
server_name (str)
events (iterable[FrozenEvent])
redact (bool): Whether to return a redacted version of the event, or
server_name
events
redact: Whether to return a redacted version of the event, or
to filter them out entirely.
check_history_visibility_only (bool): Whether to only check the
check_history_visibility_only: Whether to only check the
history visibility, rather than things like if the sender has been
erased. This is used e.g. during pagination to decide whether to
backfill or not.

Returns
list[FrozenEvent]
The filtered events.
"""

def is_sender_erased(event, erased_senders):
def is_sender_erased(event: EventBase, erased_senders: Dict[str, bool]) -> bool:
if erased_senders and erased_senders[event.sender]:
logger.info("Sender of %s has been erased, redacting", event.event_id)
return True
return False

def check_event_is_visible(event, state):
def check_event_is_visible(event: EventBase, state: StateMap[EventBase]) -> bool:
history = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if history:
visibility = history.content.get(
Expand Down