From 4f715beebf2e7cf00c79dc6dba71a309bb9ec95a Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Thu, 28 May 2020 22:36:21 +0100 Subject: [PATCH 01/31] Refactor and comment ratelimiting. Set limits in constructor --- synapse/api/ratelimiting.py | 122 +++++++++++++++++++++++++----------- 1 file changed, 87 insertions(+), 35 deletions(-) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 7a049b3af734..38d744fd9405 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -13,78 +13,130 @@ # limitations under the License. from collections import OrderedDict -from typing import Any, Optional, Tuple +from typing import Any, Tuple from synapse.api.errors import LimitExceededError class Ratelimiter(object): """ - Ratelimit message sending by user. + Ratelimit actions marked by arbitrary keys. + + Args: + rate_hz: The long term number of actions that can be performed in a + second. + burst_count: How many actions that can be performed before being + limited. """ - def __init__(self): - self.message_counts = ( - OrderedDict() - ) # type: OrderedDict[Any, Tuple[float, int, Optional[float]]] + def __init__(self, rate_hz: float, burst_count: int): + # A ordered dictionary keeping track of actions, when they were last + # performed and how often. Each entry is a mapping from a key of arbitrary type + # to a tuple representing: + # * How many times an action has occurred since a point in time + # * That point in time + self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int]] + self.rate_hz = rate_hz + self.burst_count = burst_count - def can_do_action(self, key, time_now_s, rate_hz, burst_count, update=True): + def can_do_action( + self, + key: Any, + time_now_s: int, + update: bool = True, + ) -> Tuple[bool, float]: """Can the entity (e.g. user or IP address) perform the action? + Args: key: The key we should use when rate limiting. Can be a user ID (when sending events), an IP address, etc. - time_now_s: The time now. - rate_hz: The long term number of messages a user can send in a - second. - burst_count: How many messages the user can send before being - limited. - update (bool): Whether to update the message rates or not. This is - useful to check if a message would be allowed to be sent before - its ready to be actually sent. + time_now_s: The time now + update: Whether to count this check as performing the action Returns: - A pair of a bool indicating if they can send a message now and a - time in seconds of when they can next send a message. + A tuple containing: + * A bool indicating if they can perform the action now + * The time in seconds of when it can next be performed. + -1 if a rate_hz has not been defined for this Ratelimiter """ - self.prune_message_counts(time_now_s) - message_count, time_start, _ignored = self.message_counts.get( - key, (0.0, time_now_s, None) + # Remove any expired entries + self._prune_message_counts(time_now_s) + + # Check if there is an existing count entry for this key + action_count, time_start, = self.actions.get( + key, (0.0, time_now_s) ) + + # Check whether performing another action is allowed time_delta = time_now_s - time_start - sent_count = message_count - time_delta * rate_hz - if sent_count < 0: + performed_count = action_count - time_delta * self.rate_hz + if performed_count < 0: + # Allow, reset back to count 1 allowed = True time_start = time_now_s - message_count = 1.0 - elif sent_count > burst_count - 1.0: + action_count = 1.0 + elif performed_count > self.burst_count - 1.0: + # Deny, we have exceeded our burst count allowed = False else: + # We haven't reached our limit yet allowed = True - message_count += 1 + action_count += 1.0 if update: - self.message_counts[key] = (message_count, time_start, rate_hz) + self.actions[key] = (action_count, time_start) - if rate_hz > 0: - time_allowed = time_start + (message_count - burst_count + 1) / rate_hz + # Figure out the time when an action can be performed again + if self.rate_hz > 0: + time_allowed = ( + time_start + (action_count - self.burst_count + 1) / self.rate_hz + ) + + # Don't give back a time in the past if time_allowed < time_now_s: time_allowed = time_now_s else: + # This does not apply time_allowed = -1 return allowed, time_allowed - def prune_message_counts(self, time_now_s): - for key in list(self.message_counts.keys()): - message_count, time_start, rate_hz = self.message_counts[key] + def _prune_message_counts(self, time_now_s: int): + """Remove message count entries that are older than + + Args: + time_now_s: The current time + """ + # We create a copy of the key list here as the dictionary is modified during + # the loop + for key in list(self.actions.keys()): + action_count, time_start = self.actions[key] + time_delta = time_now_s - time_start - if message_count - time_delta * rate_hz > 0: + if action_count - time_delta * self.rate_hz > 0: + # XXX: Should this be a continue? break else: - del self.message_counts[key] + del self.actions[key] + + def ratelimit( + self, + key: Any, + time_now_s: int, + update: bool = True, + ): + """Checks if an action can be performed. If not, raises a LimitExceededError - def ratelimit(self, key, time_now_s, rate_hz, burst_count, update=True): + Args: + key: An arbitrary key used to classify an action + time_now_s: The current time + update: Whether to count this check as performing the action + + Raises: + LimitExceededError: If an action could not be performed, along with the time in + milliseconds until the action can be performed again + """ allowed, time_allowed = self.can_do_action( - key, time_now_s, rate_hz, burst_count, update + key, time_now_s, update ) if not allowed: From 0e6ee7ca1728e6912f9bc219849bcfc1188b5cff Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Thu, 28 May 2020 22:37:28 +0100 Subject: [PATCH 02/31] Ratelimiters are instantiated by the HomeServer class This makes it simple for tests to modify/nullify them. --- synapse/server.py | 43 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 36 insertions(+), 7 deletions(-) diff --git a/synapse/server.py b/synapse/server.py index ca2deb49bbe4..440c6807d038 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -24,6 +24,7 @@ import abc import logging import os +from typing import Optional from twisted.mail.smtp import sendmail @@ -242,9 +243,31 @@ def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwar self.clock = Clock(reactor) self.distributor = Distributor() - self.ratelimiter = Ratelimiter() - self.admin_redaction_ratelimiter = Ratelimiter() - self.registration_ratelimiter = Ratelimiter() + # The rate_hz and burst_count is overridden on a per-user basis + self.request_ratelimiter = Ratelimiter( + rate_hz=0, + burst_count=0, + ) + if config.rc_admin_redaction: + self.admin_redaction_ratelimiter = Ratelimiter( + rate_hz=config.rc_admin_redaction.per_second, + burst_count=config.rc_admin_redaction.burst_count, + ) + else: + self.admin_redaction_ratelimiter = None + + self.registration_ratelimiter = Ratelimiter( + rate_hz=config.rc_registration.per_second, + burst_count=config.rc_registration.burst_count, + ) + self.login_ratelimiter = Ratelimiter( + rate_hz=config.rc_login_account.per_second, + burst_count=config.rc_login_account.burst_count, + ) + self.login_failed_attempts_ratelimiter = Ratelimiter( + rate_hz=config.rc_login_failed_attempts.per_second, + burst_count=config.rc_login_failed_attempts.burst_count, + ) self.datastores = None @@ -314,15 +337,21 @@ def get_config(self): def get_distributor(self): return self.distributor - def get_ratelimiter(self): - return self.ratelimiter + def get_request_ratelimiter(self) -> Ratelimiter: + return self.request_ratelimiter - def get_registration_ratelimiter(self): + def get_registration_ratelimiter(self) -> Ratelimiter: return self.registration_ratelimiter - def get_admin_redaction_ratelimiter(self): + def get_admin_redaction_ratelimiter(self) -> Optional[Ratelimiter]: return self.admin_redaction_ratelimiter + def get_login_ratelimiter(self) -> Ratelimiter: + return self.login_ratelimiter + + def get_login_failed_attempts_ratelimiter(self) -> Ratelimiter: + return self.login_failed_attempts_ratelimiter + def build_federation_client(self): return FederationClient(self) From 82eac22286c4be119d13c46e459d4e3dbcb2f59e Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Thu, 28 May 2020 22:38:26 +0100 Subject: [PATCH 03/31] Modify servlets to pull Ratelimiters from HomeServer class --- synapse/config/ratelimiting.py | 8 +++- synapse/handlers/_base.py | 56 +++++++++--------------- synapse/handlers/auth.py | 10 ++--- synapse/handlers/message.py | 1 - synapse/handlers/register.py | 2 - synapse/rest/client/v1/login.py | 19 ++------ synapse/rest/client/v2_alpha/register.py | 2 - synapse/util/ratelimitutils.py | 2 +- 8 files changed, 37 insertions(+), 63 deletions(-) diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 4a3bfc435402..8e42d15fa408 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -14,9 +14,15 @@ from ._base import Config +from typing import Dict + class RateLimitConfig(object): - def __init__(self, config, defaults={"per_second": 0.17, "burst_count": 3.0}): + def __init__( + self, + config: Dict[str, float], + defaults={"per_second": 0.17, "burst_count": 3.0}, + ): self.per_second = config.get("per_second", defaults["per_second"]) self.burst_count = config.get("burst_count", defaults["burst_count"]) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 3b781d98361a..206702b6ad6b 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -19,7 +19,6 @@ import synapse.types from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import LimitExceededError from synapse.types import UserID logger = logging.getLogger(__name__) @@ -44,11 +43,16 @@ def __init__(self, hs): self.notifier = hs.get_notifier() self.state_handler = hs.get_state_handler() self.distributor = hs.get_distributor() - self.ratelimiter = hs.get_ratelimiter() - self.admin_redaction_ratelimiter = hs.get_admin_redaction_ratelimiter() self.clock = hs.get_clock() self.hs = hs + self.ratelimiter = None + self.request_ratelimiter = hs.get_request_ratelimiter() + self._rc_message = self.hs.config.rc_message + + # If special admin redaction ratelimiting is disabled, this will be None + self.admin_redaction_ratelimiter = hs.get_admin_redaction_ratelimiter() + self.server_name = hs.hostname self.event_builder_factory = hs.get_event_builder_factory() @@ -83,48 +87,30 @@ def ratelimit(self, requester, update=True, is_admin_redaction=False): if requester.app_service and not requester.app_service.is_rate_limited(): return + messages_per_second = self._rc_message.per_second + burst_count = self._rc_message.burst_count + # Check if there is a per user override in the DB. override = yield self.store.get_ratelimit_for_user(user_id) if override: - # If overriden with a null Hz then ratelimiting has been entirely + # If overridden with a null Hz then ratelimiting has been entirely # disabled for the user if not override.messages_per_second: return messages_per_second = override.messages_per_second burst_count = override.burst_count + + if is_admin_redaction and self.admin_redaction_ratelimiter: + # If we have separate config for admin redactions, use a separate + # ratelimiter as to not have user_id's clash + self.admin_redaction_ratelimiter.ratelimit(user_id, time_now, update) else: - # We default to different values if this is an admin redaction and - # the config is set - if is_admin_redaction and self.hs.config.rc_admin_redaction: - messages_per_second = self.hs.config.rc_admin_redaction.per_second - burst_count = self.hs.config.rc_admin_redaction.burst_count - else: - messages_per_second = self.hs.config.rc_message.per_second - burst_count = self.hs.config.rc_message.burst_count - - if is_admin_redaction and self.hs.config.rc_admin_redaction: - # If we have separate config for admin redactions we use a separate - # ratelimiter - allowed, time_allowed = self.admin_redaction_ratelimiter.can_do_action( - user_id, - time_now, - rate_hz=messages_per_second, - burst_count=burst_count, - update=update, - ) - else: - allowed, time_allowed = self.ratelimiter.can_do_action( - user_id, - time_now, - rate_hz=messages_per_second, - burst_count=burst_count, - update=update, - ) - if not allowed: - raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)) - ) + # Override rate and burst count per-user + self.request_ratelimiter.rate_hz = messages_per_second + self.request_ratelimiter.burst_count = burst_count + + self.request_ratelimiter.ratelimit(user_id, time_now, update) async def maybe_kick_guest_users(self, event, context=None): # Technically this function invalidates current_state by changing it. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 75b39e878c14..9aab4692f1e0 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -108,7 +108,11 @@ def __init__(self, hs): # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. - self._failed_uia_attempts_ratelimiter = Ratelimiter() + # XXX: Should this be hs.get_login_failed_attempts_ratelimiter? + self._failed_uia_attempts_ratelimiter = Ratelimiter( + rate_hz=self.hs.config.rc_login_failed_attempts.per_second, + burst_count=self.hs.config.rc_login_failed_attempts.burst_count, + ) self._clock = self.hs.get_clock() @@ -199,8 +203,6 @@ async def validate_user_via_ui_auth( self._failed_uia_attempts_ratelimiter.ratelimit( user_id, time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, update=False, ) @@ -216,8 +218,6 @@ async def validate_user_via_ui_auth( self._failed_uia_attempts_ratelimiter.can_do_action( user_id, time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, update=True, ) raise diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 681f92cafd86..649ca1f08a53 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -362,7 +362,6 @@ def __init__(self, hs): self.profile_handler = hs.get_profile_handler() self.event_builder_factory = hs.get_event_builder_factory() self.server_name = hs.hostname - self.ratelimiter = hs.get_ratelimiter() self.notifier = hs.get_notifier() self.config = hs.config self.require_membership_for_aliases = hs.config.require_membership_for_aliases diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index a6178e74a19b..99e2b3fb2c79 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -430,8 +430,6 @@ def check_registration_ratelimit(self, address): self.ratelimiter.ratelimit( address, time_now_s=time_now, - rate_hz=self.hs.config.rc_registration.per_second, - burst_count=self.hs.config.rc_registration.burst_count, ) def register_with_store( diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index d89b2e5532fa..2754a0466923 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -89,9 +89,8 @@ def __init__(self, hs): self.handlers = hs.get_handlers() self._clock = hs.get_clock() self._well_known_builder = WellKnownBuilder(hs) - self._address_ratelimiter = Ratelimiter() - self._account_ratelimiter = Ratelimiter() - self._failed_attempts_ratelimiter = Ratelimiter() + self._account_ratelimiter = hs.get_login_ratelimiter() + self._failed_attempts_ratelimiter = hs.get_login_failed_attempts_ratelimiter() def on_GET(self, request): flows = [] @@ -129,11 +128,9 @@ def on_OPTIONS(self, request): return 200, {} async def on_POST(self, request): - self._address_ratelimiter.ratelimit( + self._account_ratelimiter.ratelimit( request.getClientIP(), time_now_s=self.hs.clock.time(), - rate_hz=self.hs.config.rc_login_address.per_second, - burst_count=self.hs.config.rc_login_address.burst_count, update=True, ) @@ -206,8 +203,6 @@ async def _do_other_login(self, login_submission): self._failed_attempts_ratelimiter.ratelimit( (medium, address), time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, update=False, ) @@ -246,8 +241,6 @@ async def _do_other_login(self, login_submission): self._failed_attempts_ratelimiter.can_do_action( (medium, address), time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, update=True, ) raise LoginError(403, "", errcode=Codes.FORBIDDEN) @@ -270,8 +263,6 @@ async def _do_other_login(self, login_submission): self._failed_attempts_ratelimiter.ratelimit( qualified_user_id.lower(), time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, update=False, ) @@ -287,8 +278,6 @@ async def _do_other_login(self, login_submission): self._failed_attempts_ratelimiter.can_do_action( qualified_user_id.lower(), time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, update=True, ) raise @@ -326,8 +315,6 @@ async def _complete_login( self._account_ratelimiter.ratelimit( user_id.lower(), time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_account.per_second, - burst_count=self.hs.config.rc_login_account.burst_count, update=True, ) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index addd4cae1906..780060493885 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -401,8 +401,6 @@ async def on_POST(self, request): allowed, time_allowed = self.ratelimiter.can_do_action( client_addr, time_now_s=time_now, - rate_hz=self.hs.config.rc_registration.per_second, - burst_count=self.hs.config.rc_registration.burst_count, update=False, ) diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 5ca4521ce36c..e5efdfcd0266 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -43,7 +43,7 @@ def new_limiter(): self.ratelimiters = collections.defaultdict(new_limiter) def ratelimit(self, host): - """Used to ratelimit an incoming request from given host + """Used to ratelimit an incoming request from a given host Example usage: From a0ef594905424d085a149debfccb85d8f48c5919 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Thu, 28 May 2020 22:36:37 +0100 Subject: [PATCH 04/31] Update unittests --- tests/api/test_ratelimiting.py | 22 ++++++++++----------- tests/handlers/test_profile.py | 12 ++++++++--- tests/replication/slave/storage/_base.py | 10 ++++++++-- tests/rest/client/v1/test_events.py | 13 +++++++++--- tests/rest/client/v1/test_login.py | 12 +++++------ tests/rest/client/v1/test_rooms.py | 13 +++++++++--- tests/rest/client/v1/test_typing.py | 12 ++++++++--- tests/rest/client/v2_alpha/test_register.py | 8 ++++---- 8 files changed, 67 insertions(+), 35 deletions(-) diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index dbdd427cac22..98336a090703 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -5,35 +5,35 @@ class TestRatelimiter(unittest.TestCase): def test_allowed(self): - limiter = Ratelimiter() + limiter = Ratelimiter(rate_hz=0.1, burst_count=1) allowed, time_allowed = limiter.can_do_action( - key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1 + key="test_id", time_now_s=0 ) self.assertTrue(allowed) self.assertEquals(10.0, time_allowed) allowed, time_allowed = limiter.can_do_action( - key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1 + key="test_id", time_now_s=5 ) self.assertFalse(allowed) self.assertEquals(10.0, time_allowed) allowed, time_allowed = limiter.can_do_action( - key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1 + key="test_id", time_now_s=10 ) self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) def test_pruning(self): - limiter = Ratelimiter() - allowed, time_allowed = limiter.can_do_action( - key="test_id_1", time_now_s=0, rate_hz=0.1, burst_count=1 + limiter = Ratelimiter(rate_hz=0.1, burst_count=1) + _, _ = limiter.can_do_action( + key="test_id_1", time_now_s=0 ) - self.assertIn("test_id_1", limiter.message_counts) + self.assertIn("test_id_1", limiter.actions) - allowed, time_allowed = limiter.can_do_action( - key="test_id_2", time_now_s=10, rate_hz=0.1, burst_count=1 + _, _ = limiter.can_do_action( + key="test_id_2", time_now_s=10 ) - self.assertNotIn("test_id_1", limiter.message_counts) + self.assertNotIn("test_id_1", limiter.actions) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 8aa56f149699..a34c70f5a715 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -55,11 +55,17 @@ def register_query_handler(query_type, handler): federation_client=self.mock_federation, federation_server=Mock(), federation_registry=self.mock_registry, - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), + request_ratelimiter=NonCallableMock( + spec_set=["can_do_action", "ratelimit", "rate_hz", "burst_count"] + ), + login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.can_do_action.return_value = (True, 0) + self.request_ratelimiter = hs.get_request_ratelimiter() + self.request_ratelimiter.can_do_action.return_value = (True, 0) + + self.login_ratelimiter = hs.get_login_ratelimiter() + self.login_ratelimiter.can_do_action.return_value = (True, 0) self.store = hs.get_datastore() diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 32cb04645f91..928a3da22375 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -23,10 +23,16 @@ def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), + request_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), + login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), ) - hs.get_ratelimiter().can_do_action.return_value = (True, 0) + # Prevent ratelimiting + self.request_ratelimiter = hs.get_request_ratelimiter() + self.request_ratelimiter.can_do_action.return_value = (True, 0) + + self.login_ratelimiter = hs.get_login_ratelimiter() + self.login_ratelimiter.can_do_action.return_value = (True, 0) return hs diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index b54b06482b13..1c42f4063f75 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -41,10 +41,17 @@ def make_homeserver(self, reactor, clock): config["auto_join_rooms"] = [] hs = self.setup_test_homeserver( - config=config, ratelimiter=NonCallableMock(spec_set=["can_do_action"]) + config=config, + request_ratelimiter=NonCallableMock( + spec_set=["can_do_action", "ratelimit", "rate_hz", "burst_count"] + ), + login_ratelimiter = NonCallableMock(spec_set=["can_do_action", "ratelimit"]), ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.can_do_action.return_value = (True, 0) + self.request_ratelimiter = hs.get_request_ratelimiter() + self.request_ratelimiter.can_do_action.return_value = (True, 0) + + self.login_ratelimiter = hs.get_login_ratelimiter() + self.login_ratelimiter.can_do_action.return_value = (True, 0) hs.get_handlers().federation_handler = Mock() diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index eb8f6264fdb4..c01738ed6918 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -36,8 +36,8 @@ def make_homeserver(self, reactor, clock): return self.hs def test_POST_ratelimiting_per_address(self): - self.hs.config.rc_login_address.burst_count = 5 - self.hs.config.rc_login_address.per_second = 0.17 + self.hs.get_login_ratelimiter().burst_count = 5 + self.hs.get_login_ratelimiter().rate_hz = 0.17 # Create different users so we're sure not to be bothered by the per-user # ratelimiter. @@ -78,8 +78,8 @@ def test_POST_ratelimiting_per_address(self): self.assertEquals(channel.result["code"], b"200", channel.result) def test_POST_ratelimiting_per_account(self): - self.hs.config.rc_login_account.burst_count = 5 - self.hs.config.rc_login_account.per_second = 0.17 + self.hs.get_login_ratelimiter().burst_count = 5 + self.hs.get_login_ratelimiter().rate_hz = 0.17 self.register_user("kermit", "monkey") @@ -117,8 +117,8 @@ def test_POST_ratelimiting_per_account(self): self.assertEquals(channel.result["code"], b"200", channel.result) def test_POST_ratelimiting_per_account_failed_attempts(self): - self.hs.config.rc_login_failed_attempts.burst_count = 5 - self.hs.config.rc_login_failed_attempts.per_second = 0.17 + self.hs.get_login_failed_attempts_ratelimiter().burst_count = 5 + self.hs.get_login_failed_attempts_ratelimiter().rate_hz = 0.17 self.register_user("kermit", "monkey") diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 7dd86d0c27bd..a07884f20d36 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -49,10 +49,17 @@ def make_homeserver(self, reactor, clock): "red", http_client=None, federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), + request_ratelimiter=NonCallableMock( + spec_set=["can_do_action", "ratelimit", "rate_hz", "burst_count"] + ), + login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), ) - self.ratelimiter = self.hs.get_ratelimiter() - self.ratelimiter.can_do_action.return_value = (True, 0) + self.request_ratelimiter = self.hs.get_request_ratelimiter() + self.request_ratelimiter.can_do_action.return_value = (True, 0) + self.request_ratelimiter.rate_hz = Mock() + + self.login_ratelimiter = self.hs.get_login_ratelimiter() + self.login_ratelimiter.can_do_action.return_value = (True, 0) self.hs.get_federation_handler = Mock(return_value=Mock()) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 4bc3aaf02d2e..30bb6bd34a66 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -42,13 +42,19 @@ def make_homeserver(self, reactor, clock): "red", http_client=None, federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), + request_ratelimiter=NonCallableMock( + spec_set=["can_do_action", "ratelimit", "rate_hz", "burst_count"] + ), + login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), ) self.event_source = hs.get_event_sources().sources["typing"] - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.can_do_action.return_value = (True, 0) + self.request_ratelimiter = hs.get_request_ratelimiter() + self.request_ratelimiter.can_do_action.return_value = (True, 0) + + self.login_ratelimiter = hs.get_login_ratelimiter() + self.login_ratelimiter.can_do_action.return_value = (True, 0) hs.get_handlers().federation_handler = Mock() diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 5637ce20907f..c64b65889276 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -147,8 +147,8 @@ def test_POST_disabled_guest_registration(self): self.assertEquals(channel.json_body["error"], "Guest access is disabled") def test_POST_ratelimiting_guest(self): - self.hs.config.rc_registration.burst_count = 5 - self.hs.config.rc_registration.per_second = 0.17 + self.hs.get_registration_ratelimiter().burst_count = 5 + self.hs.get_registration_ratelimiter().rate_hz = 0.17 for i in range(0, 6): url = self.url + b"?kind=guest" @@ -169,8 +169,8 @@ def test_POST_ratelimiting_guest(self): self.assertEquals(channel.result["code"], b"200", channel.result) def test_POST_ratelimiting(self): - self.hs.config.rc_registration.burst_count = 5 - self.hs.config.rc_registration.per_second = 0.17 + self.hs.get_registration_ratelimiter().burst_count = 5 + self.hs.get_registration_ratelimiter().rate_hz = 0.17 for i in range(0, 6): params = { From 6a07c2d9ad4bcc35627f6d3f48941efd58c9a62d Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Thu, 28 May 2020 22:43:58 +0100 Subject: [PATCH 05/31] lint --- synapse/api/ratelimiting.py | 18 ++++------------- synapse/config/ratelimiting.py | 4 ++-- synapse/handlers/auth.py | 8 ++------ synapse/handlers/register.py | 3 +-- synapse/rest/client/v1/login.py | 25 ++++++------------------ synapse/rest/client/v2_alpha/register.py | 4 +--- synapse/server.py | 5 +---- tests/api/test_ratelimiting.py | 20 +++++-------------- tests/handlers/test_profile.py | 1 + tests/replication/slave/storage/_base.py | 4 +++- tests/rest/client/v1/test_events.py | 3 ++- tests/rest/client/v1/test_rooms.py | 1 + tests/rest/client/v1/test_typing.py | 1 + 13 files changed, 30 insertions(+), 67 deletions(-) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 38d744fd9405..13fff302fe67 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -40,10 +40,7 @@ def __init__(self, rate_hz: float, burst_count: int): self.burst_count = burst_count def can_do_action( - self, - key: Any, - time_now_s: int, - update: bool = True, + self, key: Any, time_now_s: int, update: bool = True, ) -> Tuple[bool, float]: """Can the entity (e.g. user or IP address) perform the action? @@ -62,9 +59,7 @@ def can_do_action( self._prune_message_counts(time_now_s) # Check if there is an existing count entry for this key - action_count, time_start, = self.actions.get( - key, (0.0, time_now_s) - ) + action_count, time_start, = self.actions.get(key, (0.0, time_now_s)) # Check whether performing another action is allowed time_delta = time_now_s - time_start @@ -119,10 +114,7 @@ def _prune_message_counts(self, time_now_s: int): del self.actions[key] def ratelimit( - self, - key: Any, - time_now_s: int, - update: bool = True, + self, key: Any, time_now_s: int, update: bool = True, ): """Checks if an action can be performed. If not, raises a LimitExceededError @@ -135,9 +127,7 @@ def ratelimit( LimitExceededError: If an action could not be performed, along with the time in milliseconds until the action can be performed again """ - allowed, time_allowed = self.can_do_action( - key, time_now_s, update - ) + allowed, time_allowed = self.can_do_action(key, time_now_s, update) if not allowed: raise LimitExceededError( diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 8e42d15fa408..2dd94bae2bb2 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config - from typing import Dict +from ._base import Config + class RateLimitConfig(object): def __init__( diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 9aab4692f1e0..089c94f8b679 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -201,9 +201,7 @@ async def validate_user_via_ui_auth( # Check if we should be ratelimited due to too many previous failed attempts self._failed_uia_attempts_ratelimiter.ratelimit( - user_id, - time_now_s=self._clock.time(), - update=False, + user_id, time_now_s=self._clock.time(), update=False, ) # build a list of supported flows @@ -216,9 +214,7 @@ async def validate_user_via_ui_auth( except LoginError: # Update the ratelimite to say we failed (`can_do_action` doesn't raise). self._failed_uia_attempts_ratelimiter.can_do_action( - user_id, - time_now_s=self._clock.time(), - update=True, + user_id, time_now_s=self._clock.time(), update=True, ) raise diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 99e2b3fb2c79..ce18b33a634b 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -428,8 +428,7 @@ def check_registration_ratelimit(self, address): time_now = self.clock.time() self.ratelimiter.ratelimit( - address, - time_now_s=time_now, + address, time_now_s=time_now, ) def register_with_store( diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 2754a0466923..19c392849a3e 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -16,7 +16,6 @@ import logging from synapse.api.errors import Codes, LoginError, SynapseError -from synapse.api.ratelimiting import Ratelimiter from synapse.http.server import finish_request from synapse.http.servlet import ( RestServlet, @@ -129,9 +128,7 @@ def on_OPTIONS(self, request): async def on_POST(self, request): self._account_ratelimiter.ratelimit( - request.getClientIP(), - time_now_s=self.hs.clock.time(), - update=True, + request.getClientIP(), time_now_s=self.hs.clock.time(), update=True, ) login_submission = parse_json_object_from_request(request) @@ -201,9 +198,7 @@ async def _do_other_login(self, login_submission): # We also apply account rate limiting using the 3PID as a key, as # otherwise using 3PID bypasses the ratelimiting based on user ID. self._failed_attempts_ratelimiter.ratelimit( - (medium, address), - time_now_s=self._clock.time(), - update=False, + (medium, address), time_now_s=self._clock.time(), update=False, ) # Check for login providers that support 3pid login types @@ -239,9 +234,7 @@ async def _do_other_login(self, login_submission): # this code path, which is fine as then the per-user ratelimit # will kick in below. self._failed_attempts_ratelimiter.can_do_action( - (medium, address), - time_now_s=self._clock.time(), - update=True, + (medium, address), time_now_s=self._clock.time(), update=True, ) raise LoginError(403, "", errcode=Codes.FORBIDDEN) @@ -261,9 +254,7 @@ async def _do_other_login(self, login_submission): # Check if we've hit the failed ratelimit (but don't update it) self._failed_attempts_ratelimiter.ratelimit( - qualified_user_id.lower(), - time_now_s=self._clock.time(), - update=False, + qualified_user_id.lower(), time_now_s=self._clock.time(), update=False, ) try: @@ -276,9 +267,7 @@ async def _do_other_login(self, login_submission): # exception and masking the LoginError. The actual ratelimiting # should have happened above. self._failed_attempts_ratelimiter.can_do_action( - qualified_user_id.lower(), - time_now_s=self._clock.time(), - update=True, + qualified_user_id.lower(), time_now_s=self._clock.time(), update=True, ) raise @@ -313,9 +302,7 @@ async def _complete_login( # too often. This happens here rather than before as we don't # necessarily know the user before now. self._account_ratelimiter.ratelimit( - user_id.lower(), - time_now_s=self._clock.time(), - update=True, + user_id.lower(), time_now_s=self._clock.time(), update=True, ) if create_non_existant_users: diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 780060493885..8567cbcab3a8 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -399,9 +399,7 @@ async def on_POST(self, request): time_now = self.clock.time() allowed, time_allowed = self.ratelimiter.can_do_action( - client_addr, - time_now_s=time_now, - update=False, + client_addr, time_now_s=time_now, update=False, ) if not allowed: diff --git a/synapse/server.py b/synapse/server.py index 440c6807d038..fc39b57135fa 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -244,10 +244,7 @@ def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwar self.clock = Clock(reactor) self.distributor = Distributor() # The rate_hz and burst_count is overridden on a per-user basis - self.request_ratelimiter = Ratelimiter( - rate_hz=0, - burst_count=0, - ) + self.request_ratelimiter = Ratelimiter(rate_hz=0, burst_count=0,) if config.rc_admin_redaction: self.admin_redaction_ratelimiter = Ratelimiter( rate_hz=config.rc_admin_redaction.per_second, diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 98336a090703..973c7e007c51 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -6,34 +6,24 @@ class TestRatelimiter(unittest.TestCase): def test_allowed(self): limiter = Ratelimiter(rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_do_action( - key="test_id", time_now_s=0 - ) + allowed, time_allowed = limiter.can_do_action(key="test_id", time_now_s=0) self.assertTrue(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_do_action( - key="test_id", time_now_s=5 - ) + allowed, time_allowed = limiter.can_do_action(key="test_id", time_now_s=5) self.assertFalse(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_do_action( - key="test_id", time_now_s=10 - ) + allowed, time_allowed = limiter.can_do_action(key="test_id", time_now_s=10) self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) def test_pruning(self): limiter = Ratelimiter(rate_hz=0.1, burst_count=1) - _, _ = limiter.can_do_action( - key="test_id_1", time_now_s=0 - ) + _, _ = limiter.can_do_action(key="test_id_1", time_now_s=0) self.assertIn("test_id_1", limiter.actions) - _, _ = limiter.can_do_action( - key="test_id_2", time_now_s=10 - ) + _, _ = limiter.can_do_action(key="test_id_2", time_now_s=10) self.assertNotIn("test_id_1", limiter.actions) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index a34c70f5a715..5b2dcde2ba60 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -56,6 +56,7 @@ def register_query_handler(query_type, handler): federation_server=Mock(), federation_registry=self.mock_registry, request_ratelimiter=NonCallableMock( + # rate_hz and burst_count are overridden in BaseHandler spec_set=["can_do_action", "ratelimit", "rate_hz", "burst_count"] ), login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 928a3da22375..49d22d9487eb 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -23,7 +23,9 @@ def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( federation_client=Mock(), - request_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), + request_ratelimiter=NonCallableMock( + spec_set=["can_do_action", "ratelimit"] + ), login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), ) diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 1c42f4063f75..1ceba014940b 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -43,9 +43,10 @@ def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( config=config, request_ratelimiter=NonCallableMock( + # rate_hz and burst_count are overridden in BaseHandler spec_set=["can_do_action", "ratelimit", "rate_hz", "burst_count"] ), - login_ratelimiter = NonCallableMock(spec_set=["can_do_action", "ratelimit"]), + login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), ) self.request_ratelimiter = hs.get_request_ratelimiter() self.request_ratelimiter.can_do_action.return_value = (True, 0) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index a07884f20d36..28b7ce085bdd 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -50,6 +50,7 @@ def make_homeserver(self, reactor, clock): http_client=None, federation_client=Mock(), request_ratelimiter=NonCallableMock( + # rate_hz and burst_count are overridden in BaseHandler spec_set=["can_do_action", "ratelimit", "rate_hz", "burst_count"] ), login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 30bb6bd34a66..27d38d354aa0 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -43,6 +43,7 @@ def make_homeserver(self, reactor, clock): http_client=None, federation_client=Mock(), request_ratelimiter=NonCallableMock( + # rate_hz and burst_count are overridden in BaseHandler spec_set=["can_do_action", "ratelimit", "rate_hz", "burst_count"] ), login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), From c322ba00a3481a115dade3fdd1ac987762b09526 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Thu, 28 May 2020 22:57:21 +0100 Subject: [PATCH 06/31] changelog --- changelog.d/7595.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/7595.misc diff --git a/changelog.d/7595.misc b/changelog.d/7595.misc new file mode 100644 index 000000000000..db1c12beeb29 --- /dev/null +++ b/changelog.d/7595.misc @@ -0,0 +1 @@ +Refactor `Ratelimiter` and try to limit the amount of related, expensive config value accesses. From f6203a60e099651b51dc3d755dbd6b1c6aa8ce08 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Fri, 29 May 2020 18:30:44 +0100 Subject: [PATCH 07/31] Make rate_hz and burst_count overridable per-request --- synapse/api/ratelimiting.py | 66 ++++++++++++++++++++--------- synapse/handlers/_base.py | 14 +++--- tests/handlers/test_profile.py | 3 +- tests/rest/client/v1/test_rooms.py | 4 +- tests/rest/client/v1/test_typing.py | 3 +- 5 files changed, 58 insertions(+), 32 deletions(-) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 13fff302fe67..79b7631172bc 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -13,7 +13,7 @@ # limitations under the License. from collections import OrderedDict -from typing import Any, Tuple +from typing import Any, Optional, Tuple from synapse.api.errors import LimitExceededError @@ -23,10 +23,8 @@ class Ratelimiter(object): Ratelimit actions marked by arbitrary keys. Args: - rate_hz: The long term number of actions that can be performed in a - second. - burst_count: How many actions that can be performed before being - limited. + rate_hz: The long term number of actions that can be performed in a second. + burst_count: How many actions that can be performed before being limited. """ def __init__(self, rate_hz: float, burst_count: int): @@ -40,7 +38,12 @@ def __init__(self, rate_hz: float, burst_count: int): self.burst_count = burst_count def can_do_action( - self, key: Any, time_now_s: int, update: bool = True, + self, + key: Any, + time_now_s: int, + update: bool = True, + rate_hz: Optional[float] = None, + burst_count: Optional[int] = None, ) -> Tuple[bool, float]: """Can the entity (e.g. user or IP address) perform the action? @@ -49,27 +52,36 @@ def can_do_action( (when sending events), an IP address, etc. time_now_s: The time now update: Whether to count this check as performing the action + rate_hz: The long term number of actions that can be performed in a second. + Overrides the value set during instantiation if set. + burst_count: How many actions that can be performed before being limited. + Overrides the value set during instantiation if set. + Returns: A tuple containing: * A bool indicating if they can perform the action now * The time in seconds of when it can next be performed. -1 if a rate_hz has not been defined for this Ratelimiter """ + # Override default values if set + rate_hz = rate_hz or self.rate_hz + burst_count = burst_count or self.burst_count + # Remove any expired entries - self._prune_message_counts(time_now_s) + self._prune_message_counts(time_now_s, rate_hz) # Check if there is an existing count entry for this key action_count, time_start, = self.actions.get(key, (0.0, time_now_s)) # Check whether performing another action is allowed time_delta = time_now_s - time_start - performed_count = action_count - time_delta * self.rate_hz + performed_count = action_count - time_delta * rate_hz if performed_count < 0: # Allow, reset back to count 1 allowed = True time_start = time_now_s action_count = 1.0 - elif performed_count > self.burst_count - 1.0: + elif performed_count > burst_count - 1.0: # Deny, we have exceeded our burst count allowed = False else: @@ -82,9 +94,7 @@ def can_do_action( # Figure out the time when an action can be performed again if self.rate_hz > 0: - time_allowed = ( - time_start + (action_count - self.burst_count + 1) / self.rate_hz - ) + time_allowed = time_start + (action_count - burst_count + 1) / rate_hz # Don't give back a time in the past if time_allowed < time_now_s: @@ -95,26 +105,34 @@ def can_do_action( return allowed, time_allowed - def _prune_message_counts(self, time_now_s: int): - """Remove message count entries that are older than + def _prune_message_counts(self, time_now_s: int, rate_hz: float): + """Remove message count entries that have not exceeded their defined + rate_hz limit Args: time_now_s: The current time + rate_hz: The long term number of actions that can be performed in a second. """ # We create a copy of the key list here as the dictionary is modified during # the loop for key in list(self.actions.keys()): action_count, time_start = self.actions[key] + # Rate limit = "seconds since we started limiting this action" * rate_hz + # If this limit has not been exceeded, wipe our record of this action time_delta = time_now_s - time_start - if action_count - time_delta * self.rate_hz > 0: - # XXX: Should this be a continue? - break + if action_count - time_delta * rate_hz > 0: + continue else: del self.actions[key] def ratelimit( - self, key: Any, time_now_s: int, update: bool = True, + self, + key: Any, + time_now_s: int, + update: bool = True, + rate_hz: Optional[float] = None, + burst_count: Optional[int] = None, ): """Checks if an action can be performed. If not, raises a LimitExceededError @@ -122,12 +140,22 @@ def ratelimit( key: An arbitrary key used to classify an action time_now_s: The current time update: Whether to count this check as performing the action + rate_hz: The long term number of actions that can be performed in a second. + Overrides the value set during instantiation if set. + burst_count: How many actions that can be performed before being limited. + Overrides the value set during instantiation if set. Raises: LimitExceededError: If an action could not be performed, along with the time in milliseconds until the action can be performed again """ - allowed, time_allowed = self.can_do_action(key, time_now_s, update) + # Override default values if set + rate_hz = rate_hz or self.rate_hz + burst_count = burst_count or self.burst_count + + allowed, time_allowed = self.can_do_action( + key, time_now_s, update=update, rate_hz=rate_hz, burst_count=burst_count + ) if not allowed: raise LimitExceededError( diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 206702b6ad6b..e10e2427c439 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -46,7 +46,6 @@ def __init__(self, hs): self.clock = hs.get_clock() self.hs = hs - self.ratelimiter = None self.request_ratelimiter = hs.get_request_ratelimiter() self._rc_message = self.hs.config.rc_message @@ -103,14 +102,17 @@ def ratelimit(self, requester, update=True, is_admin_redaction=False): if is_admin_redaction and self.admin_redaction_ratelimiter: # If we have separate config for admin redactions, use a separate - # ratelimiter as to not have user_id's clash + # ratelimiter as to not have user_ids clash self.admin_redaction_ratelimiter.ratelimit(user_id, time_now, update) else: # Override rate and burst count per-user - self.request_ratelimiter.rate_hz = messages_per_second - self.request_ratelimiter.burst_count = burst_count - - self.request_ratelimiter.ratelimit(user_id, time_now, update) + self.request_ratelimiter.ratelimit( + user_id, + time_now, + update, + rate_hz=messages_per_second, + burst_count=burst_count, + ) async def maybe_kick_guest_users(self, event, context=None): # Technically this function invalidates current_state by changing it. diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 5b2dcde2ba60..891c986fbc05 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -56,8 +56,7 @@ def register_query_handler(query_type, handler): federation_server=Mock(), federation_registry=self.mock_registry, request_ratelimiter=NonCallableMock( - # rate_hz and burst_count are overridden in BaseHandler - spec_set=["can_do_action", "ratelimit", "rate_hz", "burst_count"] + spec_set=["can_do_action", "ratelimit"] ), login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), ) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 28b7ce085bdd..ba10f3446849 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -50,14 +50,12 @@ def make_homeserver(self, reactor, clock): http_client=None, federation_client=Mock(), request_ratelimiter=NonCallableMock( - # rate_hz and burst_count are overridden in BaseHandler - spec_set=["can_do_action", "ratelimit", "rate_hz", "burst_count"] + spec_set=["can_do_action", "ratelimit"] ), login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), ) self.request_ratelimiter = self.hs.get_request_ratelimiter() self.request_ratelimiter.can_do_action.return_value = (True, 0) - self.request_ratelimiter.rate_hz = Mock() self.login_ratelimiter = self.hs.get_login_ratelimiter() self.login_ratelimiter.can_do_action.return_value = (True, 0) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 27d38d354aa0..2ec678a2a2d1 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -43,8 +43,7 @@ def make_homeserver(self, reactor, clock): http_client=None, federation_client=Mock(), request_ratelimiter=NonCallableMock( - # rate_hz and burst_count are overridden in BaseHandler - spec_set=["can_do_action", "ratelimit", "rate_hz", "burst_count"] + spec_set=["can_do_action", "ratelimit"] ), login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), ) From 1f6156be4312816e041b317393604702882aa577 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Mon, 1 Jun 2020 18:28:50 +0100 Subject: [PATCH 08/31] Set clock with constructor, store rate_hz per key again --- synapse/api/ratelimiting.py | 48 ++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 79b7631172bc..938316f66ffb 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -16,6 +16,7 @@ from typing import Any, Optional, Tuple from synapse.api.errors import LimitExceededError +from synapse.util import Clock class Ratelimiter(object): @@ -23,24 +24,30 @@ class Ratelimiter(object): Ratelimit actions marked by arbitrary keys. Args: + clock: A homeserver clock, for retrieving the current time rate_hz: The long term number of actions that can be performed in a second. burst_count: How many actions that can be performed before being limited. """ - def __init__(self, rate_hz: float, burst_count: int): + def __init__(self, clock: Clock, rate_hz: float, burst_count: int): + self.clock = clock + self.rate_hz = rate_hz + self.burst_count = burst_count + # A ordered dictionary keeping track of actions, when they were last # performed and how often. Each entry is a mapping from a key of arbitrary type # to a tuple representing: # * How many times an action has occurred since a point in time - # * That point in time - self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int]] - self.rate_hz = rate_hz - self.burst_count = burst_count + # * The point in time + # * The rate_hz of this particular entry. This can vary per request + self.actions = ( + OrderedDict() + ) # type: OrderedDict[Any, Tuple[float, int, Optional[float]]] def can_do_action( self, key: Any, - time_now_s: int, + time_now_s: Optional[int] = None, update: bool = True, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, @@ -50,7 +57,8 @@ def can_do_action( Args: key: The key we should use when rate limiting. Can be a user ID (when sending events), an IP address, etc. - time_now_s: The time now + time_now_s: The current time. Optional, defaults to the current time according + to self.clock. Pretty much only used for tests. update: Whether to count this check as performing the action rate_hz: The long term number of actions that can be performed in a second. Overrides the value set during instantiation if set. @@ -64,14 +72,15 @@ def can_do_action( -1 if a rate_hz has not been defined for this Ratelimiter """ # Override default values if set - rate_hz = rate_hz or self.rate_hz - burst_count = burst_count or self.burst_count + time_now_s = time_now_s if time_now_s is not None else self.clock.time() + rate_hz = rate_hz if rate_hz is not None else self.rate_hz + burst_count = burst_count if burst_count is not None else self.burst_count # Remove any expired entries - self._prune_message_counts(time_now_s, rate_hz) + self._prune_message_counts(time_now_s) # Check if there is an existing count entry for this key - action_count, time_start, = self.actions.get(key, (0.0, time_now_s)) + action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, None)) # Check whether performing another action is allowed time_delta = time_now_s - time_start @@ -90,7 +99,7 @@ def can_do_action( action_count += 1.0 if update: - self.actions[key] = (action_count, time_start) + self.actions[key] = (action_count, time_start, rate_hz) # Figure out the time when an action can be performed again if self.rate_hz > 0: @@ -105,18 +114,17 @@ def can_do_action( return allowed, time_allowed - def _prune_message_counts(self, time_now_s: int, rate_hz: float): + def _prune_message_counts(self, time_now_s: int): """Remove message count entries that have not exceeded their defined rate_hz limit Args: time_now_s: The current time - rate_hz: The long term number of actions that can be performed in a second. """ # We create a copy of the key list here as the dictionary is modified during # the loop for key in list(self.actions.keys()): - action_count, time_start = self.actions[key] + action_count, time_start, rate_hz = self.actions[key] # Rate limit = "seconds since we started limiting this action" * rate_hz # If this limit has not been exceeded, wipe our record of this action @@ -129,7 +137,7 @@ def _prune_message_counts(self, time_now_s: int, rate_hz: float): def ratelimit( self, key: Any, - time_now_s: int, + time_now_s: Optional[int] = None, update: bool = True, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, @@ -138,7 +146,8 @@ def ratelimit( Args: key: An arbitrary key used to classify an action - time_now_s: The current time + time_now_s: The current time. Optional, defaults to the current time according + to self.clock. Pretty much only used for tests. update: Whether to count this check as performing the action rate_hz: The long term number of actions that can be performed in a second. Overrides the value set during instantiation if set. @@ -150,8 +159,9 @@ def ratelimit( milliseconds until the action can be performed again """ # Override default values if set - rate_hz = rate_hz or self.rate_hz - burst_count = burst_count or self.burst_count + time_now_s = time_now_s if time_now_s is not None else self.clock.time() + rate_hz = rate_hz if rate_hz is not None else self.rate_hz + burst_count = burst_count if burst_count is not None else self.burst_count allowed, time_allowed = self.can_do_action( key, time_now_s, update=update, rate_hz=rate_hz, burst_count=burst_count From c236806a0894f6c4849c95a3e4f5e0a4bc3bba86 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Mon, 1 Jun 2020 18:31:31 +0100 Subject: [PATCH 09/31] Instantiate Ratelimiters in respective classes Other than the registration ratelimiter, as that's used in multiple classes and needs to keep the same state --- synapse/handlers/_base.py | 18 +++++++-- synapse/handlers/auth.py | 11 ++---- synapse/handlers/register.py | 6 +-- synapse/rest/client/v1/login.py | 48 ++++++++++++++---------- synapse/rest/client/v2_alpha/register.py | 12 +----- synapse/server.py | 31 +-------------- 6 files changed, 51 insertions(+), 75 deletions(-) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index e10e2427c439..44cc364ad815 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -19,6 +19,7 @@ import synapse.types from synapse.api.constants import EventTypes, Membership +from synapse.api.ratelimiting import Ratelimiter from synapse.types import UserID logger = logging.getLogger(__name__) @@ -46,11 +47,22 @@ def __init__(self, hs): self.clock = hs.get_clock() self.hs = hs - self.request_ratelimiter = hs.get_request_ratelimiter() + # The rate_hz and burst_count are overridden on a per-user basis + self.request_ratelimiter = Ratelimiter( + clock=self.clock, rate_hz=0, burst_count=0 + ) self._rc_message = self.hs.config.rc_message - # If special admin redaction ratelimiting is disabled, this will be None - self.admin_redaction_ratelimiter = hs.get_admin_redaction_ratelimiter() + # Check whether ratelimiting room admin message redaction is enabled + # by the presence of rate limits in the config + if self.hs.config.rc_admin_redaction: + self.admin_redaction_ratelimiter = Ratelimiter( + clock=self.clock, + rate_hz=self.hs.config.rc_admin_redaction.per_second, + burst_count=self.hs.config.rc_admin_redaction.burst_count, + ) + else: + self.admin_redaction_ratelimiter = None self.server_name = hs.hostname diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 089c94f8b679..893491166121 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -110,6 +110,7 @@ def __init__(self, hs): # as per `rc_login.failed_attempts`. # XXX: Should this be hs.get_login_failed_attempts_ratelimiter? self._failed_uia_attempts_ratelimiter = Ratelimiter( + clock=self.clock, rate_hz=self.hs.config.rc_login_failed_attempts.per_second, burst_count=self.hs.config.rc_login_failed_attempts.burst_count, ) @@ -200,9 +201,7 @@ async def validate_user_via_ui_auth( user_id = requester.user.to_string() # Check if we should be ratelimited due to too many previous failed attempts - self._failed_uia_attempts_ratelimiter.ratelimit( - user_id, time_now_s=self._clock.time(), update=False, - ) + self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False) # build a list of supported flows flows = [[login_type] for login_type in self._supported_ui_auth_types] @@ -212,10 +211,8 @@ async def validate_user_via_ui_auth( flows, request, request_body, clientip, description ) except LoginError: - # Update the ratelimite to say we failed (`can_do_action` doesn't raise). - self._failed_uia_attempts_ratelimiter.can_do_action( - user_id, time_now_s=self._clock.time(), update=True, - ) + # Update the ratelimiter to say we failed (`can_do_action` doesn't raise). + self._failed_uia_attempts_ratelimiter.can_do_action(user_id) raise # find the completed login type diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ce18b33a634b..a138f9557d73 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -425,11 +425,7 @@ def check_registration_ratelimit(self, address): if not address: return - time_now = self.clock.time() - - self.ratelimiter.ratelimit( - address, time_now_s=time_now, - ) + self.ratelimiter.ratelimit(address) def register_with_store( self, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 19c392849a3e..a6e22bd4b8b6 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -16,6 +16,7 @@ import logging from synapse.api.errors import Codes, LoginError, SynapseError +from synapse.api.ratelimiting import Ratelimiter from synapse.http.server import finish_request from synapse.http.servlet import ( RestServlet, @@ -86,10 +87,29 @@ def __init__(self, hs): self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.handlers = hs.get_handlers() - self._clock = hs.get_clock() self._well_known_builder = WellKnownBuilder(hs) - self._account_ratelimiter = hs.get_login_ratelimiter() - self._failed_attempts_ratelimiter = hs.get_login_failed_attempts_ratelimiter() + self._address_ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_address.per_second, + burst_count=self.hs.config.rc_login_address.burst_count, + ) + self._account_ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_account.per_second, + burst_count=self.hs.config.rc_login_account.burst_count, + ) + print( + "Creating fail ratelimiter: %s %s" + % ( + self.hs.config.rc_login_failed_attempts.per_second, + self.hs.config.rc_login_failed_attempts.burst_count, + ), + ) + self._failed_attempts_ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_failed_attempts.per_second, + burst_count=self.hs.config.rc_login_failed_attempts.burst_count, + ) def on_GET(self, request): flows = [] @@ -127,9 +147,7 @@ def on_OPTIONS(self, request): return 200, {} async def on_POST(self, request): - self._account_ratelimiter.ratelimit( - request.getClientIP(), time_now_s=self.hs.clock.time(), update=True, - ) + self._address_ratelimiter.ratelimit(request.getClientIP()) login_submission = parse_json_object_from_request(request) try: @@ -197,9 +215,7 @@ async def _do_other_login(self, login_submission): # We also apply account rate limiting using the 3PID as a key, as # otherwise using 3PID bypasses the ratelimiting based on user ID. - self._failed_attempts_ratelimiter.ratelimit( - (medium, address), time_now_s=self._clock.time(), update=False, - ) + self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False) # Check for login providers that support 3pid login types ( @@ -233,9 +249,7 @@ async def _do_other_login(self, login_submission): # If it returned None but the 3PID was bound then we won't hit # this code path, which is fine as then the per-user ratelimit # will kick in below. - self._failed_attempts_ratelimiter.can_do_action( - (medium, address), time_now_s=self._clock.time(), update=True, - ) + self._failed_attempts_ratelimiter.can_do_action((medium, address)) raise LoginError(403, "", errcode=Codes.FORBIDDEN) identifier = {"type": "m.id.user", "user": user_id} @@ -254,7 +268,7 @@ async def _do_other_login(self, login_submission): # Check if we've hit the failed ratelimit (but don't update it) self._failed_attempts_ratelimiter.ratelimit( - qualified_user_id.lower(), time_now_s=self._clock.time(), update=False, + qualified_user_id.lower(), update=False ) try: @@ -266,9 +280,7 @@ async def _do_other_login(self, login_submission): # limiter. Using `can_do_action` avoids us raising a ratelimit # exception and masking the LoginError. The actual ratelimiting # should have happened above. - self._failed_attempts_ratelimiter.can_do_action( - qualified_user_id.lower(), time_now_s=self._clock.time(), update=True, - ) + self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower()) raise result = await self._complete_login( @@ -301,9 +313,7 @@ async def _complete_login( # Before we actually log them in we check if they've already logged in # too often. This happens here rather than before as we don't # necessarily know the user before now. - self._account_ratelimiter.ratelimit( - user_id.lower(), time_now_s=self._clock.time(), update=True, - ) + self._account_ratelimiter.ratelimit(user_id.lower()) if create_non_existant_users: user_id = await self.auth_handler.check_user_exists(user_id) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 8567cbcab3a8..b9ffe86b2afe 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -26,7 +26,6 @@ from synapse.api.constants import LoginType from synapse.api.errors import ( Codes, - LimitExceededError, SynapseError, ThreepidValidationError, UnrecognizedRequestError, @@ -396,16 +395,7 @@ async def on_POST(self, request): client_addr = request.getClientIP() - time_now = self.clock.time() - - allowed, time_allowed = self.ratelimiter.can_do_action( - client_addr, time_now_s=time_now, update=False, - ) - - if not allowed: - raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)) - ) + self.ratelimiter.ratelimit(client_addr, update=False) kind = b"user" if b"kind" in request.args: diff --git a/synapse/server.py b/synapse/server.py index fc39b57135fa..fe94836a2c9e 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -24,7 +24,6 @@ import abc import logging import os -from typing import Optional from twisted.mail.smtp import sendmail @@ -243,28 +242,12 @@ def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwar self.clock = Clock(reactor) self.distributor = Distributor() - # The rate_hz and burst_count is overridden on a per-user basis - self.request_ratelimiter = Ratelimiter(rate_hz=0, burst_count=0,) - if config.rc_admin_redaction: - self.admin_redaction_ratelimiter = Ratelimiter( - rate_hz=config.rc_admin_redaction.per_second, - burst_count=config.rc_admin_redaction.burst_count, - ) - else: - self.admin_redaction_ratelimiter = None self.registration_ratelimiter = Ratelimiter( + clock=self.clock, rate_hz=config.rc_registration.per_second, burst_count=config.rc_registration.burst_count, ) - self.login_ratelimiter = Ratelimiter( - rate_hz=config.rc_login_account.per_second, - burst_count=config.rc_login_account.burst_count, - ) - self.login_failed_attempts_ratelimiter = Ratelimiter( - rate_hz=config.rc_login_failed_attempts.per_second, - burst_count=config.rc_login_failed_attempts.burst_count, - ) self.datastores = None @@ -334,21 +317,9 @@ def get_config(self): def get_distributor(self): return self.distributor - def get_request_ratelimiter(self) -> Ratelimiter: - return self.request_ratelimiter - def get_registration_ratelimiter(self) -> Ratelimiter: return self.registration_ratelimiter - def get_admin_redaction_ratelimiter(self) -> Optional[Ratelimiter]: - return self.admin_redaction_ratelimiter - - def get_login_ratelimiter(self) -> Ratelimiter: - return self.login_ratelimiter - - def get_login_failed_attempts_ratelimiter(self) -> Ratelimiter: - return self.login_failed_attempts_ratelimiter - def build_federation_client(self): return FederationClient(self) From 470de6e3ef209704baf60041b7abdc05dc943466 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Mon, 1 Jun 2020 18:32:06 +0100 Subject: [PATCH 10/31] Use patch for the Ratelimiter in some tests. Set using config in others --- tests/api/test_ratelimiting.py | 4 +- tests/handlers/test_profile.py | 21 ++++----- tests/replication/slave/storage/_base.py | 27 ++++++----- tests/rest/client/v1/test_events.py | 25 +++++------ tests/rest/client/v1/test_login.py | 50 ++++++++++++++++----- tests/rest/client/v1/test_rooms.py | 24 +++++----- tests/rest/client/v1/test_typing.py | 27 ++++++----- tests/rest/client/v2_alpha/test_register.py | 23 +++++++--- 8 files changed, 120 insertions(+), 81 deletions(-) diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 973c7e007c51..12425b1faacc 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -5,7 +5,7 @@ class TestRatelimiter(unittest.TestCase): def test_allowed(self): - limiter = Ratelimiter(rate_hz=0.1, burst_count=1) + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) allowed, time_allowed = limiter.can_do_action(key="test_id", time_now_s=0) self.assertTrue(allowed) self.assertEquals(10.0, time_allowed) @@ -19,7 +19,7 @@ def test_allowed(self): self.assertEquals(20.0, time_allowed) def test_pruning(self): - limiter = Ratelimiter(rate_hz=0.1, burst_count=1) + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) _, _ = limiter.can_do_action(key="test_id_1", time_now_s=0) self.assertIn("test_id_1", limiter.actions) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 891c986fbc05..cf9026d3a79f 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -14,12 +14,13 @@ # limitations under the License. -from mock import Mock, NonCallableMock +from mock import Mock, patch from twisted.internet import defer import synapse.types from synapse.api.errors import AuthError, SynapseError +from synapse.api.ratelimiting import Ratelimiter from synapse.handlers.profile import MasterProfileHandler from synapse.types import UserID @@ -55,17 +56,17 @@ def register_query_handler(query_type, handler): federation_client=self.mock_federation, federation_server=Mock(), federation_registry=self.mock_registry, - request_ratelimiter=NonCallableMock( - spec_set=["can_do_action", "ratelimit"] - ), - login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), ) - self.request_ratelimiter = hs.get_request_ratelimiter() - self.request_ratelimiter.can_do_action.return_value = (True, 0) - - self.login_ratelimiter = hs.get_login_ratelimiter() - self.login_ratelimiter.can_do_action.return_value = (True, 0) + # Patch Ratelimiter to allow all requests + patch.object( + Ratelimiter, + "can_do_action", + new_callable=lambda *args, **kwargs: (True, 0.0), + ) + patch.object( + Ratelimiter, "ratelimit", new_callable=lambda *args, **kwargs: None + ) self.store = hs.get_datastore() diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 49d22d9487eb..f3e0af6460f0 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -13,7 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock, NonCallableMock +from mock import Mock, patch + +from synapse.api.ratelimiting import Ratelimiter from tests.replication._base import BaseStreamTestCase @@ -21,20 +23,17 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase): def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver( - federation_client=Mock(), - request_ratelimiter=NonCallableMock( - spec_set=["can_do_action", "ratelimit"] - ), - login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), - ) - - # Prevent ratelimiting - self.request_ratelimiter = hs.get_request_ratelimiter() - self.request_ratelimiter.can_do_action.return_value = (True, 0) + hs = self.setup_test_homeserver(federation_client=Mock(),) - self.login_ratelimiter = hs.get_login_ratelimiter() - self.login_ratelimiter.can_do_action.return_value = (True, 0) + # Patch Ratelimiter to allow all requests + patch.object( + Ratelimiter, + "can_do_action", + new_callable=lambda *args, **kwargs: (True, 0.0), + ) + patch.object( + Ratelimiter, "ratelimit", new_callable=lambda *args, **kwargs: None + ) return hs diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 1ceba014940b..95dd24fa5c27 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -15,9 +15,10 @@ """ Tests REST events for /events paths.""" -from mock import Mock, NonCallableMock +from mock import Mock, patch import synapse.rest.admin +from synapse.api.ratelimiting import Ratelimiter from synapse.rest.client.v1 import events, login, room from tests import unittest @@ -40,19 +41,17 @@ def make_homeserver(self, reactor, clock): config["enable_registration"] = True config["auto_join_rooms"] = [] - hs = self.setup_test_homeserver( - config=config, - request_ratelimiter=NonCallableMock( - # rate_hz and burst_count are overridden in BaseHandler - spec_set=["can_do_action", "ratelimit", "rate_hz", "burst_count"] - ), - login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), - ) - self.request_ratelimiter = hs.get_request_ratelimiter() - self.request_ratelimiter.can_do_action.return_value = (True, 0) + hs = self.setup_test_homeserver(config=config,) - self.login_ratelimiter = hs.get_login_ratelimiter() - self.login_ratelimiter.can_do_action.return_value = (True, 0) + # Patch Ratelimiter to allow all requests + patch.object( + Ratelimiter, + "can_do_action", + new_callable=lambda *args, **kwargs: (True, 0.0), + ) + patch.object( + Ratelimiter, "ratelimit", new_callable=lambda *args, **kwargs: None + ) hs.get_handlers().federation_handler = Mock() diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index c01738ed6918..619670b1b294 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -26,7 +26,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() self.hs.config.enable_registration = True self.hs.config.registrations_require_3pid = [] @@ -35,10 +34,20 @@ def make_homeserver(self, reactor, clock): return self.hs + @override_config( + { + "rc_login": { + "address": {"per_second": 0.17, "burst_count": 5,}, + # Prevent the account login ratelimiter from raising first + # + # This is normally covered by the default test homeserver config + # which sets these values to 10000, but as we're overriding the entire + # rc_login dict here, we need to set this manually as well + "account": {"per_second": 10000, "burst_count": 10000,}, + } + } + ) def test_POST_ratelimiting_per_address(self): - self.hs.get_login_ratelimiter().burst_count = 5 - self.hs.get_login_ratelimiter().rate_hz = 0.17 - # Create different users so we're sure not to be bothered by the per-user # ratelimiter. for i in range(0, 6): @@ -77,10 +86,20 @@ def test_POST_ratelimiting_per_address(self): self.assertEquals(channel.result["code"], b"200", channel.result) + @override_config( + { + "rc_login": { + "account": {"per_second": 0.17, "burst_count": 5,}, + # Prevent the address login ratelimiter from raising first + # + # This is normally covered by the default test homeserver config + # which sets these values to 10000, but as we're overriding the entire + # rc_login dict here, we need to set this manually as well + "address": {"per_second": 10000, "burst_count": 10000,}, + } + } + ) def test_POST_ratelimiting_per_account(self): - self.hs.get_login_ratelimiter().burst_count = 5 - self.hs.get_login_ratelimiter().rate_hz = 0.17 - self.register_user("kermit", "monkey") for i in range(0, 6): @@ -116,10 +135,21 @@ def test_POST_ratelimiting_per_account(self): self.assertEquals(channel.result["code"], b"200", channel.result) + @override_config( + { + "rc_login": { + # Prevent the address login ratelimiter from raising first + # + # This is normally covered by the default test homeserver config + # which sets these values to 10000, but as we're overriding the entire + # rc_login dict here, we need to set this manually as well + "address": {"per_second": 10000, "burst_count": 10000,}, + "failed_attempts": {"per_second": 0.17, "burst_count": 5,}, + } + } + ) + @unittest.DEBUG def test_POST_ratelimiting_per_account_failed_attempts(self): - self.hs.get_login_failed_attempts_ratelimiter().burst_count = 5 - self.hs.get_login_failed_attempts_ratelimiter().rate_hz = 0.17 - self.register_user("kermit", "monkey") for i in range(0, 6): diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index ba10f3446849..af2617d7e391 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -20,13 +20,14 @@ import json -from mock import Mock, NonCallableMock +from mock import Mock, patch from six.moves.urllib import parse as urlparse from twisted.internet import defer import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.ratelimiting import Ratelimiter from synapse.handlers.pagination import PurgeStatus from synapse.rest.client.v1 import directory, login, profile, room from synapse.rest.client.v2_alpha import account @@ -46,19 +47,18 @@ class RoomBase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - request_ratelimiter=NonCallableMock( - spec_set=["can_do_action", "ratelimit"] - ), - login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), + "red", http_client=None, federation_client=Mock(), ) - self.request_ratelimiter = self.hs.get_request_ratelimiter() - self.request_ratelimiter.can_do_action.return_value = (True, 0) - self.login_ratelimiter = self.hs.get_login_ratelimiter() - self.login_ratelimiter.can_do_action.return_value = (True, 0) + # Patch Ratelimiter to allow all requests + patch.object( + Ratelimiter, + "can_do_action", + new_callable=lambda *args, **kwargs: (True, 0.0), + ) + patch.object( + Ratelimiter, "ratelimit", new_callable=lambda *args, **kwargs: None + ) self.hs.get_federation_handler = Mock(return_value=Mock()) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 2ec678a2a2d1..4d9f882bf1d1 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -16,10 +16,11 @@ """Tests REST events for /rooms paths.""" -from mock import Mock, NonCallableMock +from mock import Mock, patch from twisted.internet import defer +from synapse.api.ratelimiting import Ratelimiter from synapse.rest.client.v1 import room from synapse.types import UserID @@ -39,22 +40,20 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - request_ratelimiter=NonCallableMock( - spec_set=["can_do_action", "ratelimit"] - ), - login_ratelimiter=NonCallableMock(spec_set=["can_do_action", "ratelimit"]), + "red", http_client=None, federation_client=Mock(), ) - self.event_source = hs.get_event_sources().sources["typing"] - - self.request_ratelimiter = hs.get_request_ratelimiter() - self.request_ratelimiter.can_do_action.return_value = (True, 0) + # Patch Ratelimiter to allow all requests + patch.object( + Ratelimiter, + "can_do_action", + new_callable=lambda *args, **kwargs: (True, 0.0), + ) + patch.object( + Ratelimiter, "ratelimit", new_callable=lambda *args, **kwargs: None + ) - self.login_ratelimiter = hs.get_login_ratelimiter() - self.login_ratelimiter.can_do_action.return_value = (True, 0) + self.event_source = hs.get_event_sources().sources["typing"] hs.get_handlers().federation_handler = Mock() diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index c64b65889276..761fa97684c4 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -29,6 +29,7 @@ from synapse.rest.client.v2_alpha import account, account_validity, register, sync from tests import unittest +from tests.unittest import override_config class RegisterRestServletTestCase(unittest.HomeserverTestCase): @@ -146,10 +147,15 @@ def test_POST_disabled_guest_registration(self): self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.json_body["error"], "Guest access is disabled") + @override_config( + { + "rc_registration": { + "per_second": 0.17, + "burst_count": 5, + } + } + ) def test_POST_ratelimiting_guest(self): - self.hs.get_registration_ratelimiter().burst_count = 5 - self.hs.get_registration_ratelimiter().rate_hz = 0.17 - for i in range(0, 6): url = self.url + b"?kind=guest" request, channel = self.make_request(b"POST", url, b"{}") @@ -168,10 +174,15 @@ def test_POST_ratelimiting_guest(self): self.assertEquals(channel.result["code"], b"200", channel.result) + @override_config( + { + "rc_registration": { + "per_second": 0.17, + "burst_count": 5, + } + } + ) def test_POST_ratelimiting(self): - self.hs.get_registration_ratelimiter().burst_count = 5 - self.hs.get_registration_ratelimiter().rate_hz = 0.17 - for i in range(0, 6): params = { "username": "kermit" + str(i), From 515a18607c5a8ff67e315000678d5cb11d30aaae Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Mon, 1 Jun 2020 18:35:45 +0100 Subject: [PATCH 11/31] Update copyright header --- synapse/api/ratelimiting.py | 1 + 1 file changed, 1 insertion(+) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 938316f66ffb..3b0c6a3b0499 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -1,4 +1,5 @@ # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 87ab83631eb3b1ff661387ee34cea48118ed61e1 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Mon, 1 Jun 2020 19:06:07 +0100 Subject: [PATCH 12/31] Remove resolved question --- synapse/handlers/auth.py | 1 - 1 file changed, 1 deletion(-) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 893491166121..119678e67ba9 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -108,7 +108,6 @@ def __init__(self, hs): # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. - # XXX: Should this be hs.get_login_failed_attempts_ratelimiter? self._failed_uia_attempts_ratelimiter = Ratelimiter( clock=self.clock, rate_hz=self.hs.config.rc_login_failed_attempts.per_second, From 56c52a56ba694ec4c4dacd02254752f3d435f29a Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Mon, 1 Jun 2020 19:19:47 +0100 Subject: [PATCH 13/31] lint --- tests/rest/client/v1/test_login.py | 12 ++++++------ tests/rest/client/v2_alpha/test_register.py | 18 ++---------------- 2 files changed, 8 insertions(+), 22 deletions(-) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 619670b1b294..b4969d8dd389 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -37,13 +37,13 @@ def make_homeserver(self, reactor, clock): @override_config( { "rc_login": { - "address": {"per_second": 0.17, "burst_count": 5,}, + "address": {"per_second": 0.17, "burst_count": 5}, # Prevent the account login ratelimiter from raising first # # This is normally covered by the default test homeserver config # which sets these values to 10000, but as we're overriding the entire # rc_login dict here, we need to set this manually as well - "account": {"per_second": 10000, "burst_count": 10000,}, + "account": {"per_second": 10000, "burst_count": 10000}, } } ) @@ -89,13 +89,13 @@ def test_POST_ratelimiting_per_address(self): @override_config( { "rc_login": { - "account": {"per_second": 0.17, "burst_count": 5,}, + "account": {"per_second": 0.17, "burst_count": 5}, # Prevent the address login ratelimiter from raising first # # This is normally covered by the default test homeserver config # which sets these values to 10000, but as we're overriding the entire # rc_login dict here, we need to set this manually as well - "address": {"per_second": 10000, "burst_count": 10000,}, + "address": {"per_second": 10000, "burst_count": 10000}, } } ) @@ -143,8 +143,8 @@ def test_POST_ratelimiting_per_account(self): # This is normally covered by the default test homeserver config # which sets these values to 10000, but as we're overriding the entire # rc_login dict here, we need to set this manually as well - "address": {"per_second": 10000, "burst_count": 10000,}, - "failed_attempts": {"per_second": 0.17, "burst_count": 5,}, + "address": {"per_second": 10000, "burst_count": 10000}, + "failed_attempts": {"per_second": 0.17, "burst_count": 5}, } } ) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 761fa97684c4..0f67c69ab3c5 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -147,14 +147,7 @@ def test_POST_disabled_guest_registration(self): self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.json_body["error"], "Guest access is disabled") - @override_config( - { - "rc_registration": { - "per_second": 0.17, - "burst_count": 5, - } - } - ) + @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5,}}) def test_POST_ratelimiting_guest(self): for i in range(0, 6): url = self.url + b"?kind=guest" @@ -174,14 +167,7 @@ def test_POST_ratelimiting_guest(self): self.assertEquals(channel.result["code"], b"200", channel.result) - @override_config( - { - "rc_registration": { - "per_second": 0.17, - "burst_count": 5, - } - } - ) + @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5,}}) def test_POST_ratelimiting(self): for i in range(0, 6): params = { From 2d7e0879eba46be41abced9f6122c18d1d8c728a Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Mon, 1 Jun 2020 19:25:55 +0100 Subject: [PATCH 14/31] lint, mypy --- synapse/api/ratelimiting.py | 6 ++---- tests/rest/client/v2_alpha/test_register.py | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 3b0c6a3b0499..44ebbf38e066 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -41,9 +41,7 @@ def __init__(self, clock: Clock, rate_hz: float, burst_count: int): # * How many times an action has occurred since a point in time # * The point in time # * The rate_hz of this particular entry. This can vary per request - self.actions = ( - OrderedDict() - ) # type: OrderedDict[Any, Tuple[float, int, Optional[float]]] + self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]] def can_do_action( self, @@ -81,7 +79,7 @@ def can_do_action( self._prune_message_counts(time_now_s) # Check if there is an existing count entry for this key - action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, None)) + action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, 0.0)) # Check whether performing another action is allowed time_delta = time_now_s - time_start diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 0f67c69ab3c5..7deaf5b24a48 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -147,7 +147,7 @@ def test_POST_disabled_guest_registration(self): self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.json_body["error"], "Guest access is disabled") - @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5,}}) + @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) def test_POST_ratelimiting_guest(self): for i in range(0, 6): url = self.url + b"?kind=guest" @@ -167,7 +167,7 @@ def test_POST_ratelimiting_guest(self): self.assertEquals(channel.result["code"], b"200", channel.result) - @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5,}}) + @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) def test_POST_ratelimiting(self): for i in range(0, 6): params = { From a566b46ec774f383aac9aa5b6024cb78a9066188 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Mon, 1 Jun 2020 19:38:20 +0100 Subject: [PATCH 15/31] Remove unittest.DEBUG statement --- tests/rest/client/v1/test_login.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index b4969d8dd389..99c70ae50cd7 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -148,7 +148,6 @@ def test_POST_ratelimiting_per_account(self): } } ) - @unittest.DEBUG def test_POST_ratelimiting_per_account_failed_attempts(self): self.register_user("kermit", "monkey") From 41c72888f595f6093c1b08cb243cabec06ba3ae5 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 2 Jun 2020 18:48:00 +0100 Subject: [PATCH 16/31] Update changelog.d/7595.misc Co-authored-by: Patrick Cloke --- changelog.d/7595.misc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/changelog.d/7595.misc b/changelog.d/7595.misc index db1c12beeb29..7a0646b1a3fe 100644 --- a/changelog.d/7595.misc +++ b/changelog.d/7595.misc @@ -1 +1 @@ -Refactor `Ratelimiter` and try to limit the amount of related, expensive config value accesses. +Refactor `Ratelimiter` to limit the amount of expensive config value accesses. From 58d4919b5c9252c4b10ca5e42e55a693a97a7eaf Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Tue, 2 Jun 2020 18:52:24 +0100 Subject: [PATCH 17/31] Remove erroneous print statement --- synapse/rest/client/v1/login.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index a6e22bd4b8b6..330e19ccd1e7 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -98,13 +98,6 @@ def __init__(self, hs): rate_hz=self.hs.config.rc_login_account.per_second, burst_count=self.hs.config.rc_login_account.burst_count, ) - print( - "Creating fail ratelimiter: %s %s" - % ( - self.hs.config.rc_login_failed_attempts.per_second, - self.hs.config.rc_login_failed_attempts.burst_count, - ), - ) self._failed_attempts_ratelimiter = Ratelimiter( clock=hs.get_clock(), rate_hz=self.hs.config.rc_login_failed_attempts.per_second, From 39b484bd97134ff0057725be83d39914783e570a Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Tue, 2 Jun 2020 19:29:32 +0100 Subject: [PATCH 18/31] Move update after optional method arguments --- synapse/api/ratelimiting.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 44ebbf38e066..919d6f402e98 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -47,9 +47,9 @@ def can_do_action( self, key: Any, time_now_s: Optional[int] = None, - update: bool = True, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, + update: bool = True, ) -> Tuple[bool, float]: """Can the entity (e.g. user or IP address) perform the action? @@ -58,11 +58,11 @@ def can_do_action( (when sending events), an IP address, etc. time_now_s: The current time. Optional, defaults to the current time according to self.clock. Pretty much only used for tests. - update: Whether to count this check as performing the action rate_hz: The long term number of actions that can be performed in a second. Overrides the value set during instantiation if set. burst_count: How many actions that can be performed before being limited. Overrides the value set during instantiation if set. + update: Whether to count this check as performing the action Returns: A tuple containing: @@ -137,9 +137,9 @@ def ratelimit( self, key: Any, time_now_s: Optional[int] = None, - update: bool = True, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, + update: bool = True, ): """Checks if an action can be performed. If not, raises a LimitExceededError @@ -147,11 +147,11 @@ def ratelimit( key: An arbitrary key used to classify an action time_now_s: The current time. Optional, defaults to the current time according to self.clock. Pretty much only used for tests. - update: Whether to count this check as performing the action rate_hz: The long term number of actions that can be performed in a second. Overrides the value set during instantiation if set. burst_count: How many actions that can be performed before being limited. Overrides the value set during instantiation if set. + update: Whether to count this check as performing the action Raises: LimitExceededError: If an action could not be performed, along with the time in @@ -163,7 +163,7 @@ def ratelimit( burst_count = burst_count if burst_count is not None else self.burst_count allowed, time_allowed = self.can_do_action( - key, time_now_s, update=update, rate_hz=rate_hz, burst_count=burst_count + key, time_now_s, rate_hz=rate_hz, burst_count=burst_count, update=update ) if not allowed: From d727bed653d217093b5aa6e3f8f247dc0b4f9b2e Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Tue, 2 Jun 2020 19:33:50 +0100 Subject: [PATCH 19/31] Make it obvious that time_now_s is just for testing --- synapse/api/ratelimiting.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 919d6f402e98..e5d75e71e1d4 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -46,23 +46,23 @@ def __init__(self, clock: Clock, rate_hz: float, burst_count: int): def can_do_action( self, key: Any, - time_now_s: Optional[int] = None, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, update: bool = True, + _time_now_s: Optional[int] = None, ) -> Tuple[bool, float]: """Can the entity (e.g. user or IP address) perform the action? Args: key: The key we should use when rate limiting. Can be a user ID (when sending events), an IP address, etc. - time_now_s: The current time. Optional, defaults to the current time according - to self.clock. Pretty much only used for tests. rate_hz: The long term number of actions that can be performed in a second. Overrides the value set during instantiation if set. burst_count: How many actions that can be performed before being limited. Overrides the value set during instantiation if set. update: Whether to count this check as performing the action + _time_now_s: The current time. Optional, defaults to the current time according + to self.clock. Only used by tests. Returns: A tuple containing: @@ -71,7 +71,7 @@ def can_do_action( -1 if a rate_hz has not been defined for this Ratelimiter """ # Override default values if set - time_now_s = time_now_s if time_now_s is not None else self.clock.time() + time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() rate_hz = rate_hz if rate_hz is not None else self.rate_hz burst_count = burst_count if burst_count is not None else self.burst_count @@ -136,34 +136,38 @@ def _prune_message_counts(self, time_now_s: int): def ratelimit( self, key: Any, - time_now_s: Optional[int] = None, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, update: bool = True, + _time_now_s: Optional[int] = None, ): """Checks if an action can be performed. If not, raises a LimitExceededError Args: key: An arbitrary key used to classify an action - time_now_s: The current time. Optional, defaults to the current time according - to self.clock. Pretty much only used for tests. rate_hz: The long term number of actions that can be performed in a second. Overrides the value set during instantiation if set. burst_count: How many actions that can be performed before being limited. Overrides the value set during instantiation if set. update: Whether to count this check as performing the action + _time_now_s: The current time. Optional, defaults to the current time according + to self.clock. Only used by tests. Raises: LimitExceededError: If an action could not be performed, along with the time in milliseconds until the action can be performed again """ # Override default values if set - time_now_s = time_now_s if time_now_s is not None else self.clock.time() + time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() rate_hz = rate_hz if rate_hz is not None else self.rate_hz burst_count = burst_count if burst_count is not None else self.burst_count allowed, time_allowed = self.can_do_action( - key, time_now_s, rate_hz=rate_hz, burst_count=burst_count, update=update + key, + rate_hz=rate_hz, + burst_count=burst_count, + update=update, + _time_now_s=time_now_s, ) if not allowed: From 9f76a8d8011c6a23ac56f9db9a4d7fd415321958 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Tue, 2 Jun 2020 19:41:39 +0100 Subject: [PATCH 20/31] Update ratelimiter calling methods and tests --- synapse/handlers/_base.py | 6 ++---- tests/api/test_ratelimiting.py | 10 +++++----- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 44cc364ad815..61dc4beafef0 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -85,7 +85,6 @@ def ratelimit(self, requester, update=True, is_admin_redaction=False): Raises: LimitExceededError if the request should be ratelimited """ - time_now = self.clock.time() user_id = requester.user.to_string() # The AS user itself is never rate limited. @@ -115,15 +114,14 @@ def ratelimit(self, requester, update=True, is_admin_redaction=False): if is_admin_redaction and self.admin_redaction_ratelimiter: # If we have separate config for admin redactions, use a separate # ratelimiter as to not have user_ids clash - self.admin_redaction_ratelimiter.ratelimit(user_id, time_now, update) + self.admin_redaction_ratelimiter.ratelimit(user_id, update=update) else: # Override rate and burst count per-user self.request_ratelimiter.ratelimit( user_id, - time_now, - update, rate_hz=messages_per_second, burst_count=burst_count, + update=update, ) async def maybe_kick_guest_users(self, event, context=None): diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 12425b1faacc..95506458d47e 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -6,24 +6,24 @@ class TestRatelimiter(unittest.TestCase): def test_allowed(self): limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_do_action(key="test_id", time_now_s=0) + allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0) self.assertTrue(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_do_action(key="test_id", time_now_s=5) + allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5) self.assertFalse(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_do_action(key="test_id", time_now_s=10) + allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10) self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) def test_pruning(self): limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - _, _ = limiter.can_do_action(key="test_id_1", time_now_s=0) + _, _ = limiter.can_do_action(key="test_id_1", _time_now_s=0) self.assertIn("test_id_1", limiter.actions) - _, _ = limiter.can_do_action(key="test_id_2", time_now_s=10) + _, _ = limiter.can_do_action(key="test_id_2", _time_now_s=10) self.assertNotIn("test_id_1", limiter.actions) From 88679001d79929d38311c24d0ca9037c5e5a7dc8 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Tue, 2 Jun 2020 19:42:30 +0100 Subject: [PATCH 21/31] No need to re-check for None in can_do_action --- synapse/api/ratelimiting.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index e5d75e71e1d4..69006ce76b2e 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -157,20 +157,15 @@ def ratelimit( LimitExceededError: If an action could not be performed, along with the time in milliseconds until the action can be performed again """ - # Override default values if set - time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() - rate_hz = rate_hz if rate_hz is not None else self.rate_hz - burst_count = burst_count if burst_count is not None else self.burst_count - allowed, time_allowed = self.can_do_action( key, rate_hz=rate_hz, burst_count=burst_count, update=update, - _time_now_s=time_now_s, + _time_now_s=_time_now_s, ) if not allowed: raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now_s)) + retry_after_ms=int(1000 * (time_allowed - _time_now_s)) ) From ef7383fb266b0e3ce9e3a02064133e2315aa5190 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Wed, 3 Jun 2020 13:49:24 +0100 Subject: [PATCH 22/31] time_now_s is used in ratelimit --- synapse/api/ratelimiting.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 69006ce76b2e..4a5ba8c1aff4 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -157,15 +157,17 @@ def ratelimit( LimitExceededError: If an action could not be performed, along with the time in milliseconds until the action can be performed again """ + time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() + allowed, time_allowed = self.can_do_action( key, rate_hz=rate_hz, burst_count=burst_count, update=update, - _time_now_s=_time_now_s, + _time_now_s=time_now_s, ) if not allowed: raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - _time_now_s)) + retry_after_ms=int(1000 * (time_allowed - time_now_s)) ) From 189c01b7fb4628737dd1adb1367f3d2d177e955d Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Wed, 3 Jun 2020 14:26:10 +0100 Subject: [PATCH 23/31] Comment changes revolving around time_allowed --- synapse/api/ratelimiting.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 4a5ba8c1aff4..8eea3787aa50 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -68,7 +68,7 @@ def can_do_action( A tuple containing: * A bool indicating if they can perform the action now * The time in seconds of when it can next be performed. - -1 if a rate_hz has not been defined for this Ratelimiter + -1 if rate_hz is less than or equal to zero """ # Override default values if set time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() @@ -100,15 +100,17 @@ def can_do_action( if update: self.actions[key] = (action_count, time_start, rate_hz) - # Figure out the time when an action can be performed again if self.rate_hz > 0: + # Find out when the count of existing actions expires time_allowed = time_start + (action_count - burst_count + 1) / rate_hz # Don't give back a time in the past if time_allowed < time_now_s: time_allowed = time_now_s else: - # This does not apply + # XXX: Why is this -1? This seems to only be used in + # self.ratelimit. I guess so that clients get a time in the past and don't + # feel afraid to try again immediately time_allowed = -1 return allowed, time_allowed From 4a88edb16c58a77f20a7a949a35b230ed7fe023f Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Wed, 3 Jun 2020 14:43:38 +0100 Subject: [PATCH 24/31] Fix missed call to self.rate_hz --- synapse/api/ratelimiting.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 8eea3787aa50..835aaa0ad64c 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -100,13 +100,14 @@ def can_do_action( if update: self.actions[key] = (action_count, time_start, rate_hz) - if self.rate_hz > 0: + if rate_hz > 0: # Find out when the count of existing actions expires time_allowed = time_start + (action_count - burst_count + 1) / rate_hz # Don't give back a time in the past if time_allowed < time_now_s: time_allowed = time_now_s + else: # XXX: Why is this -1? This seems to only be used in # self.ratelimit. I guess so that clients get a time in the past and don't From 14a0af5c23c3a2c02e481a3958624c725b7127e5 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Wed, 3 Jun 2020 14:02:48 +0100 Subject: [PATCH 25/31] Test Ratelimiter ratelimit method and param overrides --- tests/api/test_ratelimiting.py | 84 +++++++++++++++++++++++++++++++++- 1 file changed, 83 insertions(+), 1 deletion(-) diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 95506458d47e..82562dc0a4f9 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -1,4 +1,4 @@ -from synapse.api.ratelimiting import Ratelimiter +from synapse.api.ratelimiting import LimitExceededError, Ratelimiter from tests import unittest @@ -18,6 +18,88 @@ def test_allowed(self): self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) + def test_allowed_via_ratelimit(self): + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + + # Shouldn't raise + limiter.ratelimit(key="test_id", _time_now_s=0) + + # Should raise + self.assertRaises( + LimitExceededError, limiter.ratelimit, key="test_id", _time_now_s=5, + ) + + # Shouldn't raise + limiter.ratelimit(key="test_id", _time_now_s=10) + + def test_allowed_by_overriding_parameters(self): + """Test that we can override options of a Ratelimiter that would otherwise fail + an action + """ + # Create a Ratelimiter with a very low allowed rate_hz and burst_count + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + + # First attempt should be allowed + time_now = 0 + expected_allowed = 10.0 + + # Shouldn't raise + limiter.ratelimit(("test_id",), _time_now_s=time_now, update=False) + + allowed, time_allowed = limiter.can_do_action( + key=("test_id",), _time_now_s=time_now, + ) + self.assertTrue(allowed) + self.assertEquals(expected_allowed, time_allowed) + + # Second attempt, 1s later, will fail + time_now = 1 + expected_allowed = 10.0 + + # We expect a LimitExceededError to be raised + try: + limiter.ratelimit(("test_id",), _time_now_s=time_now, update=False) + + # We shouldn't reach here + self.assertTrue(False, "LimitExceededError was not raised") + except LimitExceededError as e: + self.assertEquals(e.retry_after_ms / 1000, expected_allowed - time_now) + + allowed, time_allowed = limiter.can_do_action( + key=("test_id",), _time_now_s=time_now, + ) + self.assertFalse(allowed) + self.assertEquals(expected_allowed, time_allowed) + + # But, if we allow 10 actions/sec in this specific instance, we should be allowed + # to continue. burst_count is still 1.0 + time_now = 1 + expected_allowed = 1.1 # Changing rate_hz scales our time_allowed + + # Shouldn't raise + limiter.ratelimit( + key=("test_id",), _time_now_s=time_now, rate_hz=10, update=False + ) + + allowed, time_allowed = limiter.can_do_action( + key=("test_id",), _time_now_s=time_now, rate_hz=10, + ) + self.assertTrue(allowed) + self.assertEquals(expected_allowed, time_allowed) + + # Similarly if we allow a burst of 10 actions, but a rate_hz of 0.1 + time_now = 1 + expected_allowed = 1.0 + limiter.ratelimit( + key=("test_id",), _time_now_s=time_now, burst_count=10, update=False, + ) + + allowed, time_allowed = limiter.can_do_action( + key=("test_id",), _time_now_s=time_now, burst_count=10, + ) + self.assertTrue(allowed) + self.assertEquals(expected_allowed, time_allowed) + def test_pruning(self): limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) _, _ = limiter.can_do_action(key="test_id_1", _time_now_s=0) From c145c810f222d3dd3873eff7801c0925a315417e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 4 Jun 2020 11:54:42 -0400 Subject: [PATCH 26/31] Back out some changes. --- tests/handlers/test_profile.py | 14 ++------------ tests/replication/slave/storage/_base.py | 17 ++++------------- tests/rest/client/v1/test_events.py | 15 +++------------ tests/rest/client/v1/test_rooms.py | 18 +++++------------- tests/rest/client/v1/test_typing.py | 18 +++++------------- 5 files changed, 19 insertions(+), 63 deletions(-) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index cf9026d3a79f..5898f39c7a0f 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -14,13 +14,12 @@ # limitations under the License. -from mock import Mock, patch +from mock import Mock, NonCallableMock from twisted.internet import defer import synapse.types from synapse.api.errors import AuthError, SynapseError -from synapse.api.ratelimiting import Ratelimiter from synapse.handlers.profile import MasterProfileHandler from synapse.types import UserID @@ -56,16 +55,7 @@ def register_query_handler(query_type, handler): federation_client=self.mock_federation, federation_server=Mock(), federation_registry=self.mock_registry, - ) - - # Patch Ratelimiter to allow all requests - patch.object( - Ratelimiter, - "can_do_action", - new_callable=lambda *args, **kwargs: (True, 0.0), - ) - patch.object( - Ratelimiter, "ratelimit", new_callable=lambda *args, **kwargs: None + ratelimiter=NonCallableMock(spec_set=["can_do_action"]), ) self.store = hs.get_datastore() diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index f3e0af6460f0..59f109781d33 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock, patch - -from synapse.api.ratelimiting import Ratelimiter +from mock import Mock, NonCallableMock from tests.replication._base import BaseStreamTestCase @@ -23,16 +21,9 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase): def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver(federation_client=Mock(),) - - # Patch Ratelimiter to allow all requests - patch.object( - Ratelimiter, - "can_do_action", - new_callable=lambda *args, **kwargs: (True, 0.0), - ) - patch.object( - Ratelimiter, "ratelimit", new_callable=lambda *args, **kwargs: None + hs = self.setup_test_homeserver( + federation_client=Mock(), + ratelimiter=NonCallableMock(spec_set=["can_do_action"]), ) return hs diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 95dd24fa5c27..dcad433bef76 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -15,10 +15,9 @@ """ Tests REST events for /events paths.""" -from mock import Mock, patch +from mock import Mock, NonCallableMock import synapse.rest.admin -from synapse.api.ratelimiting import Ratelimiter from synapse.rest.client.v1 import events, login, room from tests import unittest @@ -41,16 +40,8 @@ def make_homeserver(self, reactor, clock): config["enable_registration"] = True config["auto_join_rooms"] = [] - hs = self.setup_test_homeserver(config=config,) - - # Patch Ratelimiter to allow all requests - patch.object( - Ratelimiter, - "can_do_action", - new_callable=lambda *args, **kwargs: (True, 0.0), - ) - patch.object( - Ratelimiter, "ratelimit", new_callable=lambda *args, **kwargs: None + hs = self.setup_test_homeserver( + config=config, ratelimiter=NonCallableMock(spec_set=["can_do_action"]) ) hs.get_handlers().federation_handler = Mock() diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index af2617d7e391..1ec600be686f 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -20,14 +20,13 @@ import json -from mock import Mock, patch +from mock import Mock, NonCallableMock from six.moves.urllib import parse as urlparse from twisted.internet import defer import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership -from synapse.api.ratelimiting import Ratelimiter from synapse.handlers.pagination import PurgeStatus from synapse.rest.client.v1 import directory, login, profile, room from synapse.rest.client.v2_alpha import account @@ -47,17 +46,10 @@ class RoomBase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver( - "red", http_client=None, federation_client=Mock(), - ) - - # Patch Ratelimiter to allow all requests - patch.object( - Ratelimiter, - "can_do_action", - new_callable=lambda *args, **kwargs: (True, 0.0), - ) - patch.object( - Ratelimiter, "ratelimit", new_callable=lambda *args, **kwargs: None + "red", + http_client=None, + federation_client=Mock(), + ratelimiter=NonCallableMock(spec_set=["can_do_action"]), ) self.hs.get_federation_handler = Mock(return_value=Mock()) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 4d9f882bf1d1..f46a207a6f92 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -16,11 +16,10 @@ """Tests REST events for /rooms paths.""" -from mock import Mock, patch +from mock import Mock, NonCallableMock from twisted.internet import defer -from synapse.api.ratelimiting import Ratelimiter from synapse.rest.client.v1 import room from synapse.types import UserID @@ -40,17 +39,10 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - "red", http_client=None, federation_client=Mock(), - ) - - # Patch Ratelimiter to allow all requests - patch.object( - Ratelimiter, - "can_do_action", - new_callable=lambda *args, **kwargs: (True, 0.0), - ) - patch.object( - Ratelimiter, "ratelimit", new_callable=lambda *args, **kwargs: None + "red", + http_client=None, + federation_client=Mock(), + ratelimiter=NonCallableMock(spec_set=["can_do_action"]), ) self.event_source = hs.get_event_sources().sources["typing"] From 12b4d478ca51489e10ad568ebab6710fbcfe24d1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 4 Jun 2020 12:03:16 -0400 Subject: [PATCH 27/31] Do not specify ratelimiters in tests when unnecessary. --- tests/handlers/test_profile.py | 3 +-- tests/replication/slave/storage/_base.py | 7 ++----- tests/rest/client/v1/test_events.py | 6 ++---- tests/rest/client/v1/test_rooms.py | 7 ++----- tests/rest/client/v1/test_typing.py | 7 ++----- 5 files changed, 9 insertions(+), 21 deletions(-) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 5898f39c7a0f..29dd7d9c6e9e 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -14,7 +14,7 @@ # limitations under the License. -from mock import Mock, NonCallableMock +from mock import Mock from twisted.internet import defer @@ -55,7 +55,6 @@ def register_query_handler(query_type, handler): federation_client=self.mock_federation, federation_server=Mock(), federation_registry=self.mock_registry, - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), ) self.store = hs.get_datastore() diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 59f109781d33..56497b8476ee 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock, NonCallableMock +from mock import Mock from tests.replication._base import BaseStreamTestCase @@ -21,10 +21,7 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase): def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver( - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), - ) + hs = self.setup_test_homeserver(federation_client=Mock()) return hs diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index dcad433bef76..f75520877f6f 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -15,7 +15,7 @@ """ Tests REST events for /events paths.""" -from mock import Mock, NonCallableMock +from mock import Mock import synapse.rest.admin from synapse.rest.client.v1 import events, login, room @@ -40,9 +40,7 @@ def make_homeserver(self, reactor, clock): config["enable_registration"] = True config["auto_join_rooms"] = [] - hs = self.setup_test_homeserver( - config=config, ratelimiter=NonCallableMock(spec_set=["can_do_action"]) - ) + hs = self.setup_test_homeserver(config=config) hs.get_handlers().federation_handler = Mock() diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 1ec600be686f..4886bbb401c1 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -20,7 +20,7 @@ import json -from mock import Mock, NonCallableMock +from mock import Mock from six.moves.urllib import parse as urlparse from twisted.internet import defer @@ -46,10 +46,7 @@ class RoomBase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), + "red", http_client=None, federation_client=Mock(), ) self.hs.get_federation_handler = Mock(return_value=Mock()) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index f46a207a6f92..18260bb90e2e 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -16,7 +16,7 @@ """Tests REST events for /rooms paths.""" -from mock import Mock, NonCallableMock +from mock import Mock from twisted.internet import defer @@ -39,10 +39,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), + "red", http_client=None, federation_client=Mock(), ) self.event_source = hs.get_event_sources().sources["typing"] From d84d7793ade45977ab9d3d2165610e3e06944590 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Thu, 4 Jun 2020 19:21:19 +0100 Subject: [PATCH 28/31] Update timestamp comment --- synapse/api/ratelimiting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 835aaa0ad64c..ec6b3a69a2af 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -67,7 +67,7 @@ def can_do_action( Returns: A tuple containing: * A bool indicating if they can perform the action now - * The time in seconds of when it can next be performed. + * The reactor timestamp for when the action can be performed next. -1 if rate_hz is less than or equal to zero """ # Override default values if set From 45a779107bada7181f151b06e838fe4f802d2033 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Thu, 4 Jun 2020 19:24:13 +0100 Subject: [PATCH 29/31] Clean up Exception raising assertion --- tests/api/test_ratelimiting.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 82562dc0a4f9..89501b5e1c67 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -57,13 +57,12 @@ def test_allowed_by_overriding_parameters(self): expected_allowed = 10.0 # We expect a LimitExceededError to be raised - try: + with self.assertRaises(LimitExceededError) as limit_exception: limiter.ratelimit(("test_id",), _time_now_s=time_now, update=False) - # We shouldn't reach here - self.assertTrue(False, "LimitExceededError was not raised") - except LimitExceededError as e: - self.assertEquals(e.retry_after_ms / 1000, expected_allowed - time_now) + self.assertEquals( + limit_exception.retry_after_ms / 1000, expected_allowed - time_now, + ) allowed, time_allowed = limiter.can_do_action( key=("test_id",), _time_now_s=time_now, From 38995899ff7ec68963a8f07d1673333519f16241 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Thu, 4 Jun 2020 19:52:25 +0100 Subject: [PATCH 30/31] Clean up and split out tests --- tests/api/test_ratelimiting.py | 89 +++++++++++++++------------------- 1 file changed, 38 insertions(+), 51 deletions(-) diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 89501b5e1c67..389b2579f478 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -4,7 +4,7 @@ class TestRatelimiter(unittest.TestCase): - def test_allowed(self): + def test_allowed_via_can_do_action(self): limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0) self.assertTrue(allowed) @@ -25,79 +25,66 @@ def test_allowed_via_ratelimit(self): limiter.ratelimit(key="test_id", _time_now_s=0) # Should raise - self.assertRaises( - LimitExceededError, limiter.ratelimit, key="test_id", _time_now_s=5, - ) + with self.assertRaises(LimitExceededError) as context: + limiter.ratelimit(key="test_id", _time_now_s=5) + self.assertEqual(context.exception.retry_after_ms, 5000) # Shouldn't raise limiter.ratelimit(key="test_id", _time_now_s=10) - def test_allowed_by_overriding_parameters(self): - """Test that we can override options of a Ratelimiter that would otherwise fail + def test_allowed_via_can_do_action_and_overriding_parameters(self): + """Test that we can override options of can_do_action that would otherwise fail an action """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) # First attempt should be allowed - time_now = 0 - expected_allowed = 10.0 - - # Shouldn't raise - limiter.ratelimit(("test_id",), _time_now_s=time_now, update=False) - - allowed, time_allowed = limiter.can_do_action( - key=("test_id",), _time_now_s=time_now, - ) + allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=0,) self.assertTrue(allowed) - self.assertEquals(expected_allowed, time_allowed) + self.assertEqual(10.0, time_allowed) # Second attempt, 1s later, will fail - time_now = 1 - expected_allowed = 10.0 - - # We expect a LimitExceededError to be raised - with self.assertRaises(LimitExceededError) as limit_exception: - limiter.ratelimit(("test_id",), _time_now_s=time_now, update=False) + allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=1,) + self.assertFalse(allowed) + self.assertEqual(10.0, time_allowed) - self.assertEquals( - limit_exception.retry_after_ms / 1000, expected_allowed - time_now, + # But, if we allow 10 actions/sec for this request, we should be allowed + # to continue. + allowed, time_allowed = limiter.can_do_action( + ("test_id",), _time_now_s=1, rate_hz=10.0 ) + self.assertTrue(allowed) + self.assertEqual(1.1, time_allowed) + # Similarly if we allow a burst of 10 actions allowed, time_allowed = limiter.can_do_action( - key=("test_id",), _time_now_s=time_now, + ("test_id",), _time_now_s=1, burst_count=10 ) - self.assertFalse(allowed) - self.assertEquals(expected_allowed, time_allowed) + self.assertTrue(allowed) + self.assertEqual(1.0, time_allowed) - # But, if we allow 10 actions/sec in this specific instance, we should be allowed - # to continue. burst_count is still 1.0 - time_now = 1 - expected_allowed = 1.1 # Changing rate_hz scales our time_allowed + def test_allowed_via_ratelimit_and_overriding_parameters(self): + """Test that we can override options of the ratelimit method that would otherwise + fail an action + """ + # Create a Ratelimiter with a very low allowed rate_hz and burst_count + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - # Shouldn't raise - limiter.ratelimit( - key=("test_id",), _time_now_s=time_now, rate_hz=10, update=False - ) + # First attempt should be allowed + limiter.ratelimit(key=("test_id",), _time_now_s=0) - allowed, time_allowed = limiter.can_do_action( - key=("test_id",), _time_now_s=time_now, rate_hz=10, - ) - self.assertTrue(allowed) - self.assertEquals(expected_allowed, time_allowed) + # Second attempt, 1s later, will fail + with self.assertRaises(LimitExceededError) as context: + limiter.ratelimit(key=("test_id",), _time_now_s=1) + self.assertEqual(context.exception.retry_after_ms, 9000) - # Similarly if we allow a burst of 10 actions, but a rate_hz of 0.1 - time_now = 1 - expected_allowed = 1.0 - limiter.ratelimit( - key=("test_id",), _time_now_s=time_now, burst_count=10, update=False, - ) + # But, if we allow 10 actions/sec for this request, we should be allowed + # to continue. + limiter.ratelimit(key=("test_id",), _time_now_s=1, rate_hz=10.0) - allowed, time_allowed = limiter.can_do_action( - key=("test_id",), _time_now_s=time_now, burst_count=10, - ) - self.assertTrue(allowed) - self.assertEquals(expected_allowed, time_allowed) + # Similarly if we allow a burst of 10 actions + limiter.ratelimit(key=("test_id",), _time_now_s=1, burst_count=10) def test_pruning(self): limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) From 08c51148f5e9aa53e6a281c881990c3eb07651e3 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Thu, 4 Jun 2020 19:52:56 +0100 Subject: [PATCH 31/31] Remove _ = style --- tests/api/test_ratelimiting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 389b2579f478..d580e729c5eb 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -88,10 +88,10 @@ def test_allowed_via_ratelimit_and_overriding_parameters(self): def test_pruning(self): limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - _, _ = limiter.can_do_action(key="test_id_1", _time_now_s=0) + limiter.can_do_action(key="test_id_1", _time_now_s=0) self.assertIn("test_id_1", limiter.actions) - _, _ = limiter.can_do_action(key="test_id_2", _time_now_s=10) + limiter.can_do_action(key="test_id_2", _time_now_s=10) self.assertNotIn("test_id_1", limiter.actions)