Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Convert some of the media REST code to async/await (#7110)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep authored Mar 20, 2020
1 parent c2db659 commit caec7d4
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 111 deletions.
1 change: 1 addition & 0 deletions changelog.d/7110.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert some of synapse.rest.media to async/await.
110 changes: 49 additions & 61 deletions synapse/rest/media/v1/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

import twisted.internet.error
import twisted.web.http
from twisted.internet import defer
from twisted.web.resource import Resource

from synapse.api.errors import (
Expand Down Expand Up @@ -114,15 +113,14 @@ def _start_update_recently_accessed(self):
"update_recently_accessed_media", self._update_recently_accessed
)

@defer.inlineCallbacks
def _update_recently_accessed(self):
async def _update_recently_accessed(self):
remote_media = self.recently_accessed_remotes
self.recently_accessed_remotes = set()

local_media = self.recently_accessed_locals
self.recently_accessed_locals = set()

yield self.store.update_cached_last_access_time(
await self.store.update_cached_last_access_time(
local_media, remote_media, self.clock.time_msec()
)

Expand All @@ -138,8 +136,7 @@ def mark_recently_accessed(self, server_name, media_id):
else:
self.recently_accessed_locals.add(media_id)

@defer.inlineCallbacks
def create_content(
async def create_content(
self, media_type, upload_name, content, content_length, auth_user
):
"""Store uploaded content for a local user and return the mxc URL
Expand All @@ -158,11 +155,11 @@ def create_content(

file_info = FileInfo(server_name=None, file_id=media_id)

fname = yield self.media_storage.store_file(content, file_info)
fname = await self.media_storage.store_file(content, file_info)

logger.info("Stored local media in file %r", fname)

yield self.store.store_local_media(
await self.store.store_local_media(
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
Expand All @@ -171,12 +168,11 @@ def create_content(
user_id=auth_user,
)

yield self._generate_thumbnails(None, media_id, media_id, media_type)
await self._generate_thumbnails(None, media_id, media_id, media_type)

return "mxc://%s/%s" % (self.server_name, media_id)

@defer.inlineCallbacks
def get_local_media(self, request, media_id, name):
async def get_local_media(self, request, media_id, name):
"""Responds to reqests for local media, if exists, or returns 404.
Args:
Expand All @@ -190,7 +186,7 @@ def get_local_media(self, request, media_id, name):
Deferred: Resolves once a response has successfully been written
to request
"""
media_info = yield self.store.get_local_media(media_id)
media_info = await self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
respond_404(request)
return
Expand All @@ -204,13 +200,12 @@ def get_local_media(self, request, media_id, name):

file_info = FileInfo(None, media_id, url_cache=url_cache)

responder = yield self.media_storage.fetch_media(file_info)
yield respond_with_responder(
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(
request, responder, media_type, media_length, upload_name
)

@defer.inlineCallbacks
def get_remote_media(self, request, server_name, media_id, name):
async def get_remote_media(self, request, server_name, media_id, name):
"""Respond to requests for remote media.
Args:
Expand All @@ -236,8 +231,8 @@ def get_remote_media(self, request, server_name, media_id, name):
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id)
with (yield self.remote_media_linearizer.queue(key)):
responder, media_info = yield self._get_remote_media_impl(
with (await self.remote_media_linearizer.queue(key)):
responder, media_info = await self._get_remote_media_impl(
server_name, media_id
)

Expand All @@ -246,14 +241,13 @@ def get_remote_media(self, request, server_name, media_id, name):
media_type = media_info["media_type"]
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
yield respond_with_responder(
await respond_with_responder(
request, responder, media_type, media_length, upload_name
)
else:
respond_404(request)

@defer.inlineCallbacks
def get_remote_media_info(self, server_name, media_id):
async def get_remote_media_info(self, server_name, media_id):
"""Gets the media info associated with the remote file, downloading
if necessary.
Expand All @@ -274,8 +268,8 @@ def get_remote_media_info(self, server_name, media_id):
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id)
with (yield self.remote_media_linearizer.queue(key)):
responder, media_info = yield self._get_remote_media_impl(
with (await self.remote_media_linearizer.queue(key)):
responder, media_info = await self._get_remote_media_impl(
server_name, media_id
)

Expand All @@ -286,8 +280,7 @@ def get_remote_media_info(self, server_name, media_id):

return media_info

@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
async def _get_remote_media_impl(self, server_name, media_id):
"""Looks for media in local cache, if not there then attempt to
download from remote server.
Expand All @@ -299,7 +292,7 @@ def _get_remote_media_impl(self, server_name, media_id):
Returns:
Deferred[(Responder, media_info)]
"""
media_info = yield self.store.get_cached_remote_media(server_name, media_id)
media_info = await self.store.get_cached_remote_media(server_name, media_id)

# file_id is the ID we use to track the file locally. If we've already
# seen the file then reuse the existing ID, otherwise genereate a new
Expand All @@ -317,19 +310,18 @@ def _get_remote_media_impl(self, server_name, media_id):
logger.info("Media is quarantined")
raise NotFoundError()

responder = yield self.media_storage.fetch_media(file_info)
responder = await self.media_storage.fetch_media(file_info)
if responder:
return responder, media_info

# Failed to find the file anywhere, lets download it.

media_info = yield self._download_remote_file(server_name, media_id, file_id)
media_info = await self._download_remote_file(server_name, media_id, file_id)

responder = yield self.media_storage.fetch_media(file_info)
responder = await self.media_storage.fetch_media(file_info)
return responder, media_info

@defer.inlineCallbacks
def _download_remote_file(self, server_name, media_id, file_id):
async def _download_remote_file(self, server_name, media_id, file_id):
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
Expand All @@ -351,7 +343,7 @@ def _download_remote_file(self, server_name, media_id, file_id):
("/_matrix/media/v1/download", server_name, media_id)
)
try:
length, headers = yield self.client.get_file(
length, headers = await self.client.get_file(
server_name,
request_path,
output_stream=f,
Expand Down Expand Up @@ -397,15 +389,15 @@ def _download_remote_file(self, server_name, media_id, file_id):
)
raise SynapseError(502, "Failed to fetch remote media")

yield finish()
await finish()

media_type = headers[b"Content-Type"][0].decode("ascii")
upload_name = get_filename_from_headers(headers)
time_now_ms = self.clock.time_msec()

logger.info("Stored remote media in file %r", fname)

yield self.store.store_cached_remote_media(
await self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
Expand All @@ -423,7 +415,7 @@ def _download_remote_file(self, server_name, media_id, file_id):
"filesystem_id": file_id,
}

yield self._generate_thumbnails(server_name, media_id, file_id, media_type)
await self._generate_thumbnails(server_name, media_id, file_id, media_type)

return media_info

Expand Down Expand Up @@ -458,16 +450,15 @@ def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type):

return t_byte_source

@defer.inlineCallbacks
def generate_local_exact_thumbnail(
async def generate_local_exact_thumbnail(
self, media_id, t_width, t_height, t_method, t_type, url_cache
):
input_path = yield self.media_storage.ensure_media_is_in_local_cache(
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(None, media_id, url_cache=url_cache)
)

thumbnailer = Thumbnailer(input_path)
t_byte_source = yield defer_to_thread(
t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
thumbnailer,
Expand All @@ -490,7 +481,7 @@ def generate_local_exact_thumbnail(
thumbnail_type=t_type,
)

output_path = yield self.media_storage.store_file(
output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
Expand All @@ -500,22 +491,21 @@ def generate_local_exact_thumbnail(

t_len = os.path.getsize(output_path)

yield self.store.store_local_thumbnail(
await self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)

return output_path

@defer.inlineCallbacks
def generate_remote_exact_thumbnail(
async def generate_remote_exact_thumbnail(
self, server_name, file_id, media_id, t_width, t_height, t_method, t_type
):
input_path = yield self.media_storage.ensure_media_is_in_local_cache(
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=False)
)

thumbnailer = Thumbnailer(input_path)
t_byte_source = yield defer_to_thread(
t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
thumbnailer,
Expand All @@ -537,7 +527,7 @@ def generate_remote_exact_thumbnail(
thumbnail_type=t_type,
)

output_path = yield self.media_storage.store_file(
output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
Expand All @@ -547,7 +537,7 @@ def generate_remote_exact_thumbnail(

t_len = os.path.getsize(output_path)

yield self.store.store_remote_media_thumbnail(
await self.store.store_remote_media_thumbnail(
server_name,
media_id,
file_id,
Expand All @@ -560,8 +550,7 @@ def generate_remote_exact_thumbnail(

return output_path

@defer.inlineCallbacks
def _generate_thumbnails(
async def _generate_thumbnails(
self, server_name, media_id, file_id, media_type, url_cache=False
):
"""Generate and store thumbnails for an image.
Expand All @@ -582,7 +571,7 @@ def _generate_thumbnails(
if not requirements:
return

input_path = yield self.media_storage.ensure_media_is_in_local_cache(
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=url_cache)
)

Expand All @@ -600,7 +589,7 @@ def _generate_thumbnails(
return

if thumbnailer.transpose_method is not None:
m_width, m_height = yield defer_to_thread(
m_width, m_height = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.transpose
)

Expand All @@ -620,11 +609,11 @@ def _generate_thumbnails(
for (t_width, t_height, t_type), t_method in iteritems(thumbnails):
# Generate the thumbnail
if t_method == "crop":
t_byte_source = yield defer_to_thread(
t_byte_source = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type
)
elif t_method == "scale":
t_byte_source = yield defer_to_thread(
t_byte_source = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type
)
else:
Expand All @@ -646,7 +635,7 @@ def _generate_thumbnails(
url_cache=url_cache,
)

output_path = yield self.media_storage.store_file(
output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
Expand All @@ -656,7 +645,7 @@ def _generate_thumbnails(

# Write to database
if server_name:
yield self.store.store_remote_media_thumbnail(
await self.store.store_remote_media_thumbnail(
server_name,
media_id,
file_id,
Expand All @@ -667,15 +656,14 @@ def _generate_thumbnails(
t_len,
)
else:
yield self.store.store_local_thumbnail(
await self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)

return {"width": m_width, "height": m_height}

@defer.inlineCallbacks
def delete_old_remote_media(self, before_ts):
old_media = yield self.store.get_remote_media_before(before_ts)
async def delete_old_remote_media(self, before_ts):
old_media = await self.store.get_remote_media_before(before_ts)

deleted = 0

Expand All @@ -689,7 +677,7 @@ def delete_old_remote_media(self, before_ts):

# TODO: Should we delete from the backup store

with (yield self.remote_media_linearizer.queue(key)):
with (await self.remote_media_linearizer.queue(key)):
full_path = self.filepaths.remote_media_filepath(origin, file_id)
try:
os.remove(full_path)
Expand All @@ -705,7 +693,7 @@ def delete_old_remote_media(self, before_ts):
)
shutil.rmtree(thumbnail_dir, ignore_errors=True)

yield self.store.delete_remote_media(origin, media_id)
await self.store.delete_remote_media(origin, media_id)
deleted += 1

return {"deleted": deleted}
Expand Down
Loading

0 comments on commit caec7d4

Please sign in to comment.