Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Use Concatenate to annotate do_execute
Browse files Browse the repository at this point in the history
I'm not sure this gives us a huge amount of type safety, see this
comment:
#12312 (comment)

In any case, it's a nice bit of practice with `ParamSpec`.
  • Loading branch information
David Robertson committed May 7, 2022
1 parent 0ce2201 commit 8554128
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ netaddr = ">=0.7.18"
# add a lower bound to the Jinja2 dependency.
Jinja2 = ">=3.0"
bleach = ">=1.4.3"
# We use `ParamSpec`, which was added in `typing-extensions` 3.10.0.0.
# We use `ParamSpec` and `Concatenate`, which were added in `typing-extensions` 3.10.0.0.
typing-extensions = ">=3.10.0"
# We enforce that we have a `cryptography` version that bundles an `openssl`
# with the latest security patches.
Expand Down
19 changes: 14 additions & 5 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

import attr
from prometheus_client import Histogram
from typing_extensions import Literal
from typing_extensions import Concatenate, Literal, ParamSpec

from twisted.enterprise import adbapi

Expand Down Expand Up @@ -194,7 +194,7 @@ def __getattr__(self, name):
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
_CallbackListEntry = Tuple[Callable[..., object], Iterable[Any], Dict[str, Any]]


P = ParamSpec("P")
R = TypeVar("R")


Expand Down Expand Up @@ -339,7 +339,13 @@ def _make_sql_one_line(self, sql: str) -> str:
"Strip newlines out of SQL so that the loggers in the DB are on one line"
return " ".join(line.strip() for line in sql.splitlines() if line.strip())

def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
def _do_execute(
self,
func: Callable[Concatenate[str, P], R],
sql: str,
*args: P.args,
**kwargs: P.kwargs,
) -> R:
sql = self._make_sql_one_line(sql)

# TODO(paul): Maybe use 'info' and 'debug' for values?
Expand All @@ -348,7 +354,10 @@ def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
sql = self.database_engine.convert_param_style(sql)
if args:
try:
sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
# The type-ignore should be redundant once mypy releases a version with
# https:/python/mypy/pull/12668. (`args` might be empty,
# (but we'll catch the index error if so.)
sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) # type: ignore[index]
except Exception:
# Don't let logging failures stop SQL from working
pass
Expand All @@ -363,7 +372,7 @@ def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
opentracing.tags.DATABASE_STATEMENT: sql,
},
):
return func(sql, *args)
return func(sql, *args, **kwargs)
except Exception as e:
sql_logger.debug("[SQL FAIL] {%s} %s", self.name, e)
raise
Expand Down

0 comments on commit 8554128

Please sign in to comment.